cherry_ingest/
lib.rs

1#![allow(clippy::should_implement_trait)]
2#![allow(clippy::field_reassign_with_default)]
3
4use std::{collections::BTreeMap, pin::Pin, sync::Arc};
5
6use anyhow::{anyhow, Context, Result};
7use arrow::record_batch::RecordBatch;
8use futures_lite::{Stream, StreamExt};
9use provider::common::{evm_query_to_generic, svm_query_to_generic};
10use serde::de::DeserializeOwned;
11
12pub mod evm;
13mod provider;
14mod rayon_async;
15pub mod svm;
16
17#[derive(Debug, Clone)]
18pub enum Query {
19    Evm(evm::Query),
20    Svm(svm::Query),
21}
22
23#[cfg(feature = "pyo3")]
24impl<'py> pyo3::FromPyObject<'py> for Query {
25    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
26        use pyo3::types::PyAnyMethods;
27
28        let kind = ob.getattr("kind").context("get kind attribute")?;
29        let kind: &str = kind.extract().context("kind as str")?;
30
31        let query = ob.getattr("params").context("get params attribute")?;
32
33        match kind {
34            "evm" => Ok(Self::Evm(query.extract().context("parse query")?)),
35            "svm" => Ok(Self::Svm(query.extract().context("parse query")?)),
36            _ => Err(anyhow!("unknown query kind: {}", kind).into()),
37        }
38    }
39}
40
41#[derive(Debug, Clone)]
42#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject))]
43pub struct ProviderConfig {
44    pub kind: ProviderKind,
45    pub query: Query,
46    pub url: Option<String>,
47    pub bearer_token: Option<String>,
48    pub max_num_retries: Option<usize>,
49    pub retry_backoff_ms: Option<u64>,
50    pub retry_base_ms: Option<u64>,
51    pub retry_ceiling_ms: Option<u64>,
52    pub req_timeout_millis: Option<u64>,
53    pub stop_on_head: bool,
54    pub head_poll_interval_millis: Option<u64>,
55    pub buffer_size: Option<usize>,
56}
57
58impl ProviderConfig {
59    pub fn new(kind: ProviderKind, query: Query) -> Self {
60        Self {
61            kind,
62            query,
63            url: None,
64            bearer_token: None,
65            max_num_retries: None,
66            retry_backoff_ms: None,
67            retry_base_ms: None,
68            retry_ceiling_ms: None,
69            req_timeout_millis: None,
70            stop_on_head: false,
71            head_poll_interval_millis: None,
72            buffer_size: None,
73        }
74    }
75}
76
/// Data backend that [`start_stream`] dispatches to.
///
/// Parses from the strings `"sqd"`, `"hypersync"` and `"yellowstone_grpc"` via
/// [`str::parse`]. With the `pyo3` feature, the same strings are accepted from Python.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    Sqd,
    Hypersync,
    YellowstoneGrpc,
}

impl std::str::FromStr for ProviderKind {
    /// Message describing the unrecognised input.
    type Err = String;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s {
            "sqd" => Ok(Self::Sqd),
            "hypersync" => Ok(Self::Hypersync),
            "yellowstone_grpc" => Ok(Self::YellowstoneGrpc),
            _ => Err(format!("unknown provider kind: {}", s)),
        }
    }
}

/// Extracts a [`ProviderKind`] from a Python string, accepting the same names as
/// its `FromStr` implementation.
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for ProviderKind {
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        let out: &str = ob.extract().context("read as string")?;

        // Single source of truth for the name -> variant mapping; the error text is
        // forwarded unchanged through anyhow into a Python exception.
        Ok(out.parse::<Self>().map_err(|msg| anyhow!("{}", msg))?)
    }
}
99
100type DataStream = Pin<Box<dyn Stream<Item = Result<BTreeMap<String, RecordBatch>>> + Send + Sync>>;
101
102fn make_req_fields<T: DeserializeOwned>(query: &cherry_query::Query) -> Result<T> {
103    let mut req_fields_query = query.clone();
104    req_fields_query
105        .add_request_and_include_fields()
106        .context("add req and include fields")?;
107
108    let fields = req_fields_query
109        .fields
110        .into_iter()
111        .map(|(k, v)| {
112            (
113                k.strip_suffix('s').unwrap().to_owned(),
114                v.into_iter()
115                    .map(|v| (v, true))
116                    .collect::<BTreeMap<String, bool>>(),
117            )
118        })
119        .collect::<BTreeMap<String, _>>();
120
121    Ok(serde_json::from_value(serde_json::to_value(&fields).unwrap()).unwrap())
122}
123
124pub async fn start_stream(mut provider_config: ProviderConfig) -> Result<DataStream> {
125    let generic_query = match &mut provider_config.query {
126        Query::Evm(query) => {
127            let generic_query = evm_query_to_generic(query);
128
129            query.fields = make_req_fields(&generic_query).context("make req fields")?;
130
131            generic_query
132        }
133        Query::Svm(query) => {
134            let generic_query = svm_query_to_generic(query);
135
136            query.fields = make_req_fields(&generic_query).context("make req fields")?;
137
138            generic_query
139        }
140    };
141    let generic_query = Arc::new(generic_query);
142
143    let stream = match provider_config.kind {
144        ProviderKind::Sqd => {
145            provider::sqd::start_stream(provider_config).context("start sqd stream")?
146        }
147        ProviderKind::Hypersync => provider::hypersync::start_stream(provider_config)
148            .await
149            .context("start hypersync stream")?,
150        ProviderKind::YellowstoneGrpc => provider::yellowstone_grpc::start_stream(provider_config)
151            .await
152            .context("start yellowstone_grpc stream")?,
153    };
154
155    let stream = stream.then(move |res| {
156        let generic_query = Arc::clone(&generic_query);
157        async {
158            rayon_async::spawn(move || {
159                res.and_then(move |data| {
160                    let data = cherry_query::run_query(&data, &generic_query)
161                        .context("run local query")?;
162                    Ok(data)
163                })
164            })
165            .await
166            .unwrap()
167        }
168    });
169
170    Ok(Box::pin(stream))
171}