use futures::StreamExt;
#[cfg(feature = "cli")]
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use lazy_static::lazy_static;
use reqwest::Url;
use sha2::Digest;
use sha2::Sha256;
use std::env;
use std::fs::File;
use std::io;
use std::io::BufReader;
use std::io::Read;
use std::io::Seek;
use std::io::Write;
use std::path::PathBuf;
use std::time::Duration;
use std::time::Instant;
use tracing::{info, warn};
/// Locates, downloads, and verifies proving data (such as public parameters)
/// cached on the local filesystem.
#[derive(Clone, Debug)]
pub struct MidnightDataProvider {
    /// Whether `get_file` may download missing artifacts on demand.
    pub fetch_mode: FetchMode,
    /// Host that artifacts are downloaded from; artifact `name` is fetched from
    /// `<base_url>/<name>`.
    pub base_url: Url,
    /// How download status and progress are reported.
    pub output_mode: OutputMode,
    /// Additional `(name, sha256, description)` artifacts managed on top of the
    /// built-in public parameters.
    pub expected_data: Vec<(&'static str, [u8; 32], &'static str)>,
    /// Local cache directory; artifact `name` is stored at `<dir>/<name>`.
    pub dir: PathBuf,
}
lazy_static! {
    /// Base URL that proving data is downloaded from.
    ///
    /// Read once, on first access, from `$MIDNIGHT_PARAM_SOURCE`; if that
    /// variable cannot be read, the Midnight S3 file share is used instead.
    ///
    /// # Panics
    ///
    /// Panics on first access if the chosen value is not a valid URL.
    pub static ref BASE_URL: Url = {
        let source = std::env::var("MIDNIGHT_PARAM_SOURCE").unwrap_or_else(|_| {
            "https://midnight-s3-fileshare-dev-eu-west-1.s3.eu-west-1.amazonaws.com/".to_owned()
        });
        Url::parse(&source).expect("$MIDNIGHT_PARAM_SOURCE should be a valid URL")
    };
}
/// Decodes a hex-encoded SHA-256 digest into its 32 raw bytes.
///
/// Being a `const fn`, this lets digest tables be written as hex literals
/// and checked at compile time.
///
/// # Panics
///
/// Panics if `hex` does not decode to exactly 32 bytes. In a `const` context
/// this becomes a compile error.
pub const fn hexhash(hex: &[u8]) -> [u8; 32] {
    let Ok(hash) = const_hex::const_decode_to_array(hex) else {
        panic!("hash should be correct format");
    };
    hash
}
/// Built-in artifacts managed by [`MidnightDataProvider`].
///
/// These are the public parameters `bls_midnight_2p{k}` for `k` in `0..=25`,
/// the same names `MidnightDataProvider::name_k` produces. Each entry is
/// `(name, sha256, description)`. A cached or downloaded file must hash to
/// `sha256` to be accepted. `description` is shown in download messages.
const EXPECTED_DATA: &[(&str, [u8; 32], &str)] = &[
(
"bls_midnight_2p0",
hexhash(b"59b30b3114a34ccbbfb599376e178fb8d9b3366cae2174c2f1da20e75847f823"),
"public parameters for k=0",
),
(
"bls_midnight_2p1",
hexhash(b"bbe04fe3c70d0c138447cb086b4baddc30cb8bb2a004114bc02e6f739516280e"),
"public parameters for k=1",
),
(
"bls_midnight_2p2",
hexhash(b"80e15568fa1a0117db893239be7fa5e34a6bcc3a8c3bfa7709534b9cb88eb6c1"),
"public parameters for k=2",
),
(
"bls_midnight_2p3",
hexhash(b"4be827a6472193df80d8f08b4b25a85baef436fdd1965d89b6af89f4ec4e99e2"),
"public parameters for k=3",
),
(
"bls_midnight_2p4",
hexhash(b"232f401fad10c7ddf8828d2aa4c85c6506c5da09795998cecaeb9f75fc8f6ada"),
"public parameters for k=4",
),
(
"bls_midnight_2p5",
hexhash(b"0a1c9229f315fc1868ff25f668fb83aec4d09f4f23a706b5197c692c619d72c6"),
"public parameters for k=5",
),
(
"bls_midnight_2p6",
hexhash(b"cf2ad6be7d0fedf5bec2aaa35f6be4aca33053d74268fdf5aa54fcb2891ea6df"),
"public parameters for k=6",
),
(
"bls_midnight_2p7",
hexhash(b"e82ae890c080188355f37feaffe91372584cd810615082d9143d4dec0453fd9d"),
"public parameters for k=7",
),
(
"bls_midnight_2p8",
hexhash(b"909b707551eaaea79828e883cde6fc46ab15986c3b1d791bed462c9e2805c933"),
"public parameters for k=8",
),
(
"bls_midnight_2p9",
hexhash(b"b9009f1098bcefffec3c461ab3a5e3a17f7e5599f0f08c70fcdc55a89227bcbd"),
"public parameters for k=9",
),
(
"bls_midnight_2p10",
hexhash(b"46b2290933cbed4c378889e4ba971f1a92888331ffb09466acd4ff61a1e2cb42"),
"public parameters for k=10",
),
(
"bls_midnight_2p11",
hexhash(b"9901589d7956ff58be0d85569b2f455b77b58c3758026ffb5bbe4807000b96d1"),
"public parameters for k=11",
),
(
"bls_midnight_2p12",
hexhash(b"ef08eb3fcf62df8f72c515cffa027e681808b530cb016eea104115545ef6d5c8"),
"public parameters for k=12",
),
(
"bls_midnight_2p13",
hexhash(b"d3324910969c4cc54143b8045b649e5c3a4bd5fb7b8f85fe1b770f640ce1c803"),
"public parameters for k=13",
),
(
"bls_midnight_2p14",
hexhash(b"fc253016885ec830e97808c9ec920bb5cab5c21af590380a6cb5eb0538e2b244"),
"public parameters for k=14",
),
(
"bls_midnight_2p15",
hexhash(b"724c7c3d779148bb113c7ee9c034b2f27db16e6bdf315fde90105a9bad00b1de"),
"public parameters for k=15",
),
(
"bls_midnight_2p16",
hexhash(b"09c877216d6589b370263e18af40a030a901b41a7a7c37ef58c9901db41f05c6"),
"public parameters for k=16",
),
(
"bls_midnight_2p17",
hexhash(b"4a9ef6c7c0619aab74eede44b13e753e3ba54508a02dd3b7106a949aabb73b74"),
"public parameters for k=17",
),
(
"bls_midnight_2p18",
hexhash(b"e8436dc5d8b598f169c127c745135d889744007e6d384ff126df8d1332522f86"),
"public parameters for k=18",
),
(
"bls_midnight_2p19",
hexhash(b"8e8dc15c4362f05c912f1e770559a3945db3e58a374def416ed5d3e65ad5b10e"),
"public parameters for k=19",
),
(
"bls_midnight_2p20",
hexhash(b"1cc62978558fdc1e445cd70cfd9a86ec3c2e2151b6d74811232d37faf9133ff1"),
"public parameters for k=20",
),
(
"bls_midnight_2p21",
hexhash(b"9cf1644a87f0f027ae5fc6278f91d823a6334ff3e338a29e2f2ef57d071ed64d"),
"public parameters for k=21",
),
(
"bls_midnight_2p22",
hexhash(b"e8ad5eed936d657a0fb59d2a55ba19f81a3083bb3554ef88f464f5377e9b2c2f"),
"public parameters for k=22",
),
(
"bls_midnight_2p23",
hexhash(b"09399d05f9f50875dfdd87dc9903d40c897eaafa9ec8cbb08bace853ecc36c0c"),
"public parameters for k=23",
),
(
"bls_midnight_2p24",
hexhash(b"b0e6fa7a4ab4a79a1e6560966f267556409db44bab6d5fab3711ad6c6b623207"),
"public parameters for k=24",
),
(
"bls_midnight_2p25",
hexhash(b"3289a751c938988cd2f54154d8722d1eda2cd11593064afdde82099b24ff4a58"),
"public parameters for k=25",
),
];
impl MidnightDataProvider {
pub fn new(
fetch_mode: FetchMode,
output_mode: OutputMode,
expected_data: Vec<(&'static str, [u8; 32], &'static str)>,
) -> io::Result<Self> {
Ok(Self {
fetch_mode,
base_url: BASE_URL.clone(),
output_mode,
expected_data,
dir: env::var_os("MIDNIGHT_PP")
.map(PathBuf::from)
.or_else(|| {
env::var_os("XDG_CACHE_HOME")
.map(|p| PathBuf::from(p).join("midnight").join("zk-params"))
})
.or_else(|| {
env::var_os("HOME").map(|p| {
PathBuf::from(p)
.join(".cache")
.join("midnight")
.join("zk-params")
})
})
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Could not determine $HOME, $XDG_CACHE_HOME, or $MIDNIGHT_PP",
)
})?,
})
}
fn expected_hash(&self, name: &str) -> io::Result<[u8; 32]> {
Ok(EXPECTED_DATA
.iter()
.chain(self.expected_data.iter())
.find(|(n, ..)| *n == name)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"artifact '{name}' is not a known managed artifact by the proof data cache."
),
)
})?
.1)
}
fn description(&self, name: &str) -> io::Result<&'static str> {
Ok(EXPECTED_DATA
.iter()
.chain(self.expected_data.iter())
.find(|(n, ..)| *n == name)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"artifact '{name}' is not a known managed artifact by the proof data cache."
),
)
})?
.2)
}
fn get_local(&self, name: &str) -> io::Result<Option<BufReader<File>>> {
let path = self.dir.join(name);
let expected_hash = self.expected_hash(name)?;
if !std::fs::exists(&path)? {
return Ok(None);
}
let mut file = BufReader::new(File::open(&path)?);
let mut hasher = Sha256::new();
let mut buf = [0u8; 1 << 20];
loop {
let read = file.read(&mut buf)?;
if read == 0 {
break;
}
hasher.update(&buf[..read]);
}
let actual_hash = <[u8; 32]>::from(hasher.finalize());
if actual_hash != expected_hash {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Hash mismatch in data stored at {}. Found hash {}, but expected {}. Please try removing this file to force a re-fetch. If that does not work, you may be subject to an attack.",
path.display(),
const_hex::encode(actual_hash),
const_hex::encode(expected_hash)
),
));
}
file.seek(io::SeekFrom::Start(0))?;
Ok(Some(file))
}
async fn get_or_fetch(&self, name: &str) -> io::Result<BufReader<File>> {
if let Some(data) = self.get_local(name)? {
return Ok(data);
};
let expected_hash = self.expected_hash(name)?;
let path = self.dir.join(name);
let parent = path.parent().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
format!("parent of path file {name} should exist."),
)
})?;
std::fs::create_dir_all(parent)?;
let mut file = atomic_write_file::OpenOptions::new()
.read(true)
.open(&path)?;
self.fetch_data_to(name, expected_hash, &mut file).await?;
let mut rfile = file.as_file().try_clone()?;
file.commit()?;
rfile.seek(io::SeekFrom::Start(0))?;
Ok(BufReader::new(rfile))
}
pub async fn fetch(&self, name: &str) -> io::Result<()> {
self.get_or_fetch(name).await?;
Ok(())
}
pub fn name_k(k: u8) -> String {
format!("bls_midnight_2p{k}")
}
pub async fn fetch_k(&self, k: u8) -> io::Result<()> {
self.fetch(&Self::name_k(k)).await
}
#[allow(irrefutable_let_patterns)]
async fn fetch_data_to(
&self,
name: &str,
expected_hash: [u8; 32],
f: &mut File,
) -> io::Result<()> {
const RETRIES: usize = 3;
let desc = self.description(name)?;
if let OutputMode::Log = &self.output_mode {
info!(
"Missing {desc}. Attempting to download from the host {} - this is not a trusted service, the data will be verified.",
self.base_url
);
}
#[cfg(feature = "cli")]
if let OutputMode::Cli(pb) = &self.output_mode {
pb.println(format!("Missing {desc}. Attempting to download from the host {} - this is not a trusted service, the data will be verified.", self.base_url))?;
}
let mut url = self.base_url.clone();
url.path_segments_mut()
.map_err(|()| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"Base URL '{}' for proving data provider invalid",
&self.base_url
),
)
})?
.push(name);
for i in 0..RETRIES {
let retry_msg = if i == RETRIES - 1 {
"Giving up."
} else {
"Retrying..."
};
f.seek(io::SeekFrom::Start(0))?;
f.set_len(0)?;
let mut hasher = Sha256::new();
let res = match reqwest::Client::new().get(url.clone()).send().await {
Ok(res) => res,
Err(e) => {
#[cfg(feature = "cli")]
if let OutputMode::Cli(pb) = &self.output_mode {
pb.println(format!("{e}. {retry_msg}"))?;
}
warn!("{e}. {retry_msg}");
continue;
}
};
let total_size = res.content_length();
#[cfg(feature = "cli")]
let pb = if let OutputMode::Cli(multi) = &self.output_mode {
let pb = match total_size {
Some(size) => ProgressBar::new(size).with_style(
ProgressStyle::with_template(
"{msg} [{bar:.green.bold}] {bytes:.bold} / {total_bytes:.bold}",
)
.expect("Static style should parse")
.progress_chars("=> "),
),
None => ProgressBar::no_length().with_style(
ProgressStyle::with_template("{msg} {spinner:.green.bold} {bytes:.bold}")
.expect("Static style should parse"),
),
};
let pb = multi.insert(0, pb);
pb.set_message(format!("Fetching {desc}"));
Some(pb)
} else {
None
};
let mut downloaded: u64 = 0;
let mut t_last = Instant::now();
const LOG_UPDATE_FREQ: Duration = Duration::from_secs(5);
let mut stream = res.bytes_stream();
while let Some(resp) = stream.next().await {
let data = match resp {
Ok(res) => res,
Err(e) => {
#[cfg(feature = "cli")]
if let OutputMode::Cli(pb) = &self.output_mode {
pb.println(format!("{e}. {retry_msg}"))?;
}
warn!("{e}. {retry_msg}");
continue;
}
};
f.write_all(&data)?;
hasher.update(&data);
downloaded += data.len() as u64;
#[cfg(feature = "cli")]
if let Some(pb) = &pb {
pb.set_position(downloaded);
}
let t = Instant::now();
if matches!(self.output_mode, OutputMode::Log) && t - t_last > LOG_UPDATE_FREQ {
t_last = t;
match total_size {
Some(size) => {
info!("Fetching '{name}' - {downloaded} / {size} bytes downloaded")
}
None => info!("Fetching '{name}' - {downloaded} bytes downloaded"),
}
}
}
info!("Fetching {desc} - finished.");
#[cfg(feature = "cli")]
if let Some(pb) = pb {
pb.finish();
}
let hash = <[u8; 32]>::from(hasher.finalize());
if hash == expected_hash {
if let OutputMode::Log = self.output_mode {
info!("Fetching {desc} - verified correct.");
}
return Ok(());
}
warn!(
?hash,
?expected_hash,
"Fetching {desc} - hash mismatch. {retry_msg}"
);
}
Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Failed to fetch data from {url} after {RETRIES} attempts. Giving up."),
))
}
pub async fn get_file(&self, name: &str, desc: &str) -> io::Result<BufReader<File>> {
Ok(match self.fetch_mode {
FetchMode::OnDemand => self.get_or_fetch(name).await?,
FetchMode::Synchronous => self
.get_local(name)?
.ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, desc))?,
})
}
}
/// Controls whether `MidnightDataProvider::get_file` may download missing data.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FetchMode {
    /// Download a missing artifact the first time it is requested.
    OnDemand,
    /// Never download from `get_file`. The artifact must already be cached,
    /// e.g. via `MidnightDataProvider::fetch`; otherwise a `NotFound` error is
    /// returned.
    Synchronous,
}
/// Selects how the data provider reports download status and progress.
#[derive(Clone, Debug)]
pub enum OutputMode {
    /// Report through `tracing` log lines, with periodic progress updates.
    Log,
    /// Draw a progress bar per download on the given `MultiProgress`, and
    /// print status messages through it.
    #[cfg(feature = "cli")]
    Cli(MultiProgress),
}