cherry_ingest/
lib.rs

1#![allow(clippy::should_implement_trait)]
2#![allow(clippy::field_reassign_with_default)]
3
4use std::{collections::BTreeMap, pin::Pin, sync::Arc};
5
6use anyhow::{anyhow, Context, Result};
7use arrow::record_batch::RecordBatch;
8use futures_lite::{Stream, StreamExt};
9use provider::common::{evm_query_to_generic, svm_query_to_generic};
10use serde::de::DeserializeOwned;
11
12pub mod evm;
13mod provider;
14mod rayon_async;
15pub mod svm;
16
17#[derive(Debug, Clone)]
18pub enum Query {
19    Evm(evm::Query),
20    Svm(svm::Query),
21}
22
/// Builds a [`Query`] from a Python object exposing `kind` and `params`
/// attributes.
///
/// `kind` must extract as a string equal to `"evm"` or `"svm"`; `params` is
/// then extracted as the matching chain-specific query type. Any other `kind`
/// value yields an error.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for Query {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        let kind_attr = ob.getattr("kind").context("get kind attribute")?;
        let kind_name: &str = kind_attr.extract().context("kind as str")?;

        let params = ob.getattr("params").context("get params attribute")?;

        // Pick the variant from the discriminator string; bail out on anything else.
        let parsed = match kind_name {
            "evm" => Self::Evm(params.extract().context("parse query")?),
            "svm" => Self::Svm(params.extract().context("parse query")?),
            other => return Err(anyhow!("unknown query kind: {}", other).into()),
        };

        Ok(parsed)
    }
}
40
41#[derive(Debug, Clone)]
42#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject))]
43pub struct ProviderConfig {
44    pub kind: ProviderKind,
45    pub url: Option<String>,
46    pub bearer_token: Option<String>,
47    pub max_num_retries: Option<usize>,
48    pub retry_backoff_ms: Option<u64>,
49    pub retry_base_ms: Option<u64>,
50    pub retry_ceiling_ms: Option<u64>,
51    pub req_timeout_millis: Option<u64>,
52    pub stop_on_head: bool,
53    pub head_poll_interval_millis: Option<u64>,
54    pub buffer_size: Option<usize>,
55}
56
57impl ProviderConfig {
58    pub fn new(kind: ProviderKind) -> Self {
59        Self {
60            kind,
61            url: None,
62            bearer_token: None,
63            max_num_retries: None,
64            retry_backoff_ms: None,
65            retry_base_ms: None,
66            retry_ceiling_ms: None,
67            req_timeout_millis: None,
68            stop_on_head: false,
69            head_poll_interval_millis: None,
70            buffer_size: None,
71        }
72    }
73}
74
/// Identifies which provider module [`start_stream`] dispatches to.
#[derive(Clone, Copy, Debug)]
pub enum ProviderKind {
    /// Dispatches to `provider::sqd`.
    Sqd,
    /// Dispatches to `provider::hypersync`.
    Hypersync,
    /// Dispatches to `provider::yellowstone_grpc`.
    YellowstoneGrpc,
}
81
/// Builds a [`ProviderKind`] from a Python string.
///
/// Accepted values are `"sqd"`, `"hypersync"` and `"yellowstone_grpc"`; any
/// other string yields an error.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for ProviderKind {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        let name: &str = ob.extract().context("read as string")?;

        let kind = match name {
            "sqd" => Self::Sqd,
            "hypersync" => Self::Hypersync,
            "yellowstone_grpc" => Self::YellowstoneGrpc,
            unknown => return Err(anyhow!("unknown provider kind: {}", unknown).into()),
        };

        Ok(kind)
    }
}
97
98type DataStream = Pin<Box<dyn Stream<Item = Result<BTreeMap<String, RecordBatch>>> + Send + Sync>>;
99
100fn make_req_fields<T: DeserializeOwned>(query: &cherry_query::Query) -> Result<T> {
101    let mut req_fields_query = query.clone();
102    req_fields_query
103        .add_request_and_include_fields()
104        .context("add req and include fields")?;
105
106    let fields = req_fields_query
107        .fields
108        .into_iter()
109        .map(|(k, v)| {
110            (
111                k.strip_suffix('s').unwrap().to_owned(),
112                v.into_iter()
113                    .map(|v| (v, true))
114                    .collect::<BTreeMap<String, bool>>(),
115            )
116        })
117        .collect::<BTreeMap<String, _>>();
118
119    Ok(serde_json::from_value(serde_json::to_value(&fields).unwrap()).unwrap())
120}
121
122pub async fn start_stream(provider_config: ProviderConfig, mut query: Query) -> Result<DataStream> {
123    let generic_query = match &mut query {
124        Query::Evm(evm_query) => {
125            let generic_query = evm_query_to_generic(evm_query);
126
127            evm_query.fields = make_req_fields(&generic_query).context("make req fields")?;
128
129            generic_query
130        }
131        Query::Svm(svm_query) => {
132            let generic_query = svm_query_to_generic(svm_query);
133
134            svm_query.fields = make_req_fields(&generic_query).context("make req fields")?;
135
136            generic_query
137        }
138    };
139    let generic_query = Arc::new(generic_query);
140
141    let stream = match provider_config.kind {
142        ProviderKind::Sqd => {
143            provider::sqd::start_stream(provider_config, query).context("start sqd stream")?
144        }
145        ProviderKind::Hypersync => provider::hypersync::start_stream(provider_config, query)
146            .await
147            .context("start hypersync stream")?,
148        ProviderKind::YellowstoneGrpc => {
149            provider::yellowstone_grpc::start_stream(provider_config, query)
150                .await
151                .context("start yellowstone_grpc stream")?
152        }
153    };
154
155    let stream = stream.then(move |res| {
156        let generic_query = Arc::clone(&generic_query);
157        async {
158            rayon_async::spawn(move || {
159                res.and_then(move |data| {
160                    let data = cherry_query::run_query(&data, &generic_query)
161                        .context("run local query")?;
162                    Ok(data)
163                })
164            })
165            .await
166            .unwrap()
167        }
168    });
169
170    Ok(Box::pin(stream))
171}
172
#[cfg(test)]
mod tests {

    use super::*;
    use crate::svm::*;
    use parquet::arrow::ArrowWriter;
    use std::fs::File;

    /// Streams one item from the SQD Solana dataset for a single block and
    /// writes each returned record batch to `<key>.parquet` in the working
    /// directory. Ignored by default because it needs network access.
    #[tokio::test]
    #[ignore]
    async fn simple_svm_start_stream() {
        let mut provider_config = ProviderConfig::new(ProviderKind::Sqd);
        provider_config.url = Some("https://portal.sqd.dev/datasets/solana-mainnet".to_string());

        // Decode the base58 program id into the fixed 32-byte address form.
        let decoded = bs58::decode("TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA")
            .into_vec()
            .unwrap();
        let raw_id: [u8; 32] = decoded.try_into().unwrap();
        let program_id = Address(raw_id);

        let fields = Fields {
            instruction: InstructionFields::all(),
            transaction: TransactionFields::default(),
            log: LogFields::default(),
            balance: BalanceFields::default(),
            token_balance: TokenBalanceFields::default(),
            reward: RewardFields::default(),
            block: BlockFields::default(),
        };

        let instruction_request = InstructionRequest {
            program_id: vec![program_id],
            discriminator: vec![Data(vec![12, 96, 49, 128, 22])],
            ..Default::default()
        };

        let query = crate::Query::Svm(svm::Query {
            from_block: 329443000,
            to_block: Some(329443000),
            include_all_blocks: false,
            fields,
            instructions: vec![instruction_request],
            transactions: vec![],
            logs: vec![],
            balances: vec![],
            token_balances: vec![],
            rewards: vec![],
        });

        let mut stream = start_stream(provider_config, query).await.unwrap();
        let first_item = stream.next().await.unwrap().unwrap();

        for (table, batch) in first_item {
            let mut out_file = File::create(format!("{}.parquet", table)).unwrap();
            let mut parquet_writer =
                ArrowWriter::try_new(&mut out_file, batch.schema(), None).unwrap();
            parquet_writer.write(&batch).unwrap();
            parquet_writer.close().unwrap();
        }
    }
}
231}