1#![allow(clippy::should_implement_trait)]
2#![allow(clippy::field_reassign_with_default)]
3
4use std::{collections::BTreeMap, pin::Pin, sync::Arc};
5
6use anyhow::{anyhow, Context, Result};
7use arrow::record_batch::RecordBatch;
8use futures_lite::{Stream, StreamExt};
9use provider::common::{evm_query_to_generic, svm_query_to_generic};
10use serde::de::DeserializeOwned;
11
12pub mod evm;
13mod provider;
14mod rayon_async;
15pub mod svm;
16
17#[derive(Debug, Clone)]
18pub enum Query {
19 Evm(evm::Query),
20 Svm(svm::Query),
21}
22
/// Extracts a [`Query`] from a Python object exposing a `kind` string attribute and a
/// `params` attribute holding the chain-specific query.
///
/// `kind` must be `"evm"` or `"svm"`; any other value is reported as an error.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for Query {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        // Read the discriminator first, then the payload, so that a missing `kind` is
        // reported before a missing `params`.
        let kind_attr = ob.getattr("kind").context("get kind attribute")?;
        let kind_name: &str = kind_attr.extract().context("kind as str")?;
        let params = ob.getattr("params").context("get params attribute")?;

        let parsed = match kind_name {
            "evm" => Self::Evm(params.extract().context("parse query")?),
            "svm" => Self::Svm(params.extract().context("parse query")?),
            other => return Err(anyhow!("unknown query kind: {}", other).into()),
        };

        Ok(parsed)
    }
}
40
41#[derive(Debug, Clone)]
42#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject))]
43pub struct ProviderConfig {
44 pub kind: ProviderKind,
45 pub url: Option<String>,
46 pub bearer_token: Option<String>,
47 pub max_num_retries: Option<usize>,
48 pub retry_backoff_ms: Option<u64>,
49 pub retry_base_ms: Option<u64>,
50 pub retry_ceiling_ms: Option<u64>,
51 pub req_timeout_millis: Option<u64>,
52 pub stop_on_head: bool,
53 pub head_poll_interval_millis: Option<u64>,
54 pub buffer_size: Option<usize>,
55}
56
57impl ProviderConfig {
58 pub fn new(kind: ProviderKind) -> Self {
59 Self {
60 kind,
61 url: None,
62 bearer_token: None,
63 max_num_retries: None,
64 retry_backoff_ms: None,
65 retry_base_ms: None,
66 retry_ceiling_ms: None,
67 req_timeout_millis: None,
68 stop_on_head: false,
69 head_poll_interval_millis: None,
70 buffer_size: None,
71 }
72 }
73}
74
/// Identifies which provider implementation a [`ProviderConfig`] targets.
#[derive(Debug, Clone, Copy)]
pub enum ProviderKind {
    /// Served by the `provider::sqd` module.
    Sqd,
    /// Served by the `provider::hypersync` module.
    Hypersync,
}
80
/// Extracts a [`ProviderKind`] from a Python string: `"sqd"` or `"hypersync"`.
///
/// Any other string is reported as an error.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for ProviderKind {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        let name: &str = ob.extract().context("read as string")?;

        let kind = match name {
            "sqd" => Self::Sqd,
            "hypersync" => Self::Hypersync,
            other => return Err(anyhow!("unknown provider kind: {}", other).into()),
        };

        Ok(kind)
    }
}
95
96type DataStream = Pin<Box<dyn Stream<Item = Result<BTreeMap<String, RecordBatch>>> + Send + Sync>>;
97
98fn make_req_fields<T: DeserializeOwned>(query: &cherry_query::Query) -> Result<T> {
99 let mut req_fields_query = query.clone();
100 req_fields_query
101 .add_request_and_include_fields()
102 .context("add req and include fields")?;
103
104 let fields = req_fields_query
105 .fields
106 .into_iter()
107 .map(|(k, v)| {
108 (
109 k.strip_suffix('s').unwrap().to_owned(),
110 v.into_iter()
111 .map(|v| (v, true))
112 .collect::<BTreeMap<String, bool>>(),
113 )
114 })
115 .collect::<BTreeMap<String, _>>();
116
117 Ok(serde_json::from_value(serde_json::to_value(&fields).unwrap()).unwrap())
118}
119
120pub async fn start_stream(provider_config: ProviderConfig, mut query: Query) -> Result<DataStream> {
121 let generic_query = match &mut query {
122 Query::Evm(evm_query) => {
123 let generic_query = evm_query_to_generic(evm_query);
124
125 evm_query.fields = make_req_fields(&generic_query).context("make req fields")?;
126
127 generic_query
128 }
129 Query::Svm(svm_query) => {
130 let generic_query = svm_query_to_generic(svm_query);
131
132 svm_query.fields = make_req_fields(&generic_query).context("make req fields")?;
133
134 generic_query
135 }
136 };
137 let generic_query = Arc::new(generic_query);
138
139 let stream = match provider_config.kind {
140 ProviderKind::Sqd => {
141 provider::sqd::start_stream(provider_config, query).context("start sqd stream")?
142 }
143 ProviderKind::Hypersync => provider::hypersync::start_stream(provider_config, query)
144 .await
145 .context("start hypersync stream")?,
146 };
147
148 let stream = stream.then(move |res| {
149 let generic_query = Arc::clone(&generic_query);
150 async {
151 rayon_async::spawn(move || {
152 res.and_then(move |data| {
153 let data = cherry_query::run_query(&data, &generic_query)
154 .context("run local query")?;
155 Ok(data)
156 })
157 })
158 .await
159 .unwrap()
160 }
161 });
162
163 Ok(Box::pin(stream))
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::svm::*;
    use parquet::arrow::ArrowWriter;
    use std::fs::File;

    /// Streams a single block from the SQD portal and writes every returned table to
    /// `<name>.parquet` in the current directory.
    ///
    /// Marked `#[ignore]`, so it only runs when ignored tests are requested explicitly.
    #[tokio::test]
    #[ignore]
    async fn simple_svm_start_stream() {
        let provider_config = ProviderConfig {
            url: Some("https://portal.sqd.dev/datasets/solana-mainnet".to_string()),
            ..ProviderConfig::new(ProviderKind::Sqd)
        };

        let token_program: [u8; 32] = bs58::decode("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA")
            .into_vec()
            .unwrap()
            .try_into()
            .unwrap();
        let token_program = Address(token_program);

        // Only instructions are requested with all fields; every other table uses defaults.
        let fields = Fields {
            instruction: InstructionFields::all(),
            transaction: TransactionFields::default(),
            log: LogFields::default(),
            balance: BalanceFields::default(),
            token_balance: TokenBalanceFields::default(),
            reward: RewardFields::default(),
            block: BlockFields::default(),
        };

        let instruction_filter = InstructionRequest {
            program_id: vec![token_program],
            discriminator: vec![Data(vec![12, 96, 49, 128, 22])],
            ..Default::default()
        };

        let query = crate::Query::Svm(svm::Query {
            from_block: 329443000,
            to_block: Some(329443000),
            include_all_blocks: false,
            fields,
            instructions: vec![instruction_filter],
            transactions: vec![],
            logs: vec![],
            balances: vec![],
            token_balances: vec![],
            rewards: vec![],
        });

        let mut stream = start_stream(provider_config, query).await.unwrap();
        let first_item = stream.next().await.unwrap().unwrap();

        for (table, batch) in first_item {
            let mut out = File::create(format!("{}.parquet", table)).unwrap();
            let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap();
            writer.write(&batch).unwrap();
            writer.close().unwrap();
        }
    }
}