1#![allow(clippy::should_implement_trait)]
2#![allow(clippy::field_reassign_with_default)]
3
4use std::{collections::BTreeMap, pin::Pin, sync::Arc};
5
6use anyhow::{anyhow, Context, Result};
7use arrow::record_batch::RecordBatch;
8use futures_lite::{Stream, StreamExt};
9use provider::common::{evm_query_to_generic, svm_query_to_generic};
10use serde::de::DeserializeOwned;
11
12pub mod evm;
13mod provider;
14mod rayon_async;
15pub mod svm;
16
/// A user query, tagged with the chain family (module) it was written for.
///
/// [`start_stream`] converts either variant into a generic `cherry_query::Query`
/// before dispatching it to a provider.
#[derive(Debug, Clone)]
pub enum Query {
    /// Query expressed with the types of the [`evm`] module.
    Evm(evm::Query),
    /// Query expressed with the types of the [`svm`] module.
    Svm(svm::Query),
}
22
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for Query {
    /// Builds a [`Query`] from a Python object that exposes `kind` and `params`.
    ///
    /// `kind` must be the string `"evm"` or `"svm"`. It selects which variant
    /// `params` is extracted into. Any other value is reported as an error.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        let kind_attr = ob.getattr("kind").context("get kind attribute")?;
        let kind_name: &str = kind_attr.extract().context("kind as str")?;
        let params = ob.getattr("params").context("get params attribute")?;

        let parsed = if kind_name == "evm" {
            Self::Evm(params.extract().context("parse query")?)
        } else if kind_name == "svm" {
            Self::Svm(params.extract().context("parse query")?)
        } else {
            return Err(anyhow!("unknown query kind: {}", kind_name).into());
        };

        Ok(parsed)
    }
}
40
/// Configuration for [`start_stream`]: which provider to use, the query to run,
/// and optional connection / retry / streaming settings.
///
/// The whole config is handed to the selected provider's `start_stream`. How a
/// provider interprets optional settings left as `None` is not visible here
/// (presumably a provider-specific default — confirm in `provider::*`).
#[derive(Debug, Clone)]
// With the `pyo3` feature enabled, the config can be extracted from a Python
// object via pyo3's derive (which by default reads each field from the
// attribute of the same name).
#[cfg_attr(feature = "pyo3", derive(pyo3::FromPyObject))]
pub struct ProviderConfig {
    /// Which provider backend [`start_stream`] dispatches to.
    pub kind: ProviderKind,
    /// The query to run. [`start_stream`] overwrites its field selection with
    /// the fields the query requires before starting the provider.
    pub query: Query,
    /// Provider endpoint URL, if set.
    pub url: Option<String>,
    /// Bearer token for the provider, if set.
    pub bearer_token: Option<String>,
    /// Maximum number of retries, if set (exact retry policy is provider-specific).
    pub max_num_retries: Option<usize>,
    /// Retry backoff setting in milliseconds — semantics provider-specific, TODO confirm.
    pub retry_backoff_ms: Option<u64>,
    /// Retry base delay in milliseconds — semantics provider-specific, TODO confirm.
    pub retry_base_ms: Option<u64>,
    /// Retry delay ceiling in milliseconds — semantics provider-specific, TODO confirm.
    pub retry_ceiling_ms: Option<u64>,
    /// Request timeout in milliseconds, if set.
    pub req_timeout_millis: Option<u64>,
    /// Presumably: stop the stream once the chain head is reached
    /// (`false` in [`ProviderConfig::new`]) — confirm in `provider::*`.
    pub stop_on_head: bool,
    /// Presumably the interval for polling the chain head, in milliseconds.
    pub head_poll_interval_millis: Option<u64>,
    /// Buffer size hint for the provider stream; units are provider-specific.
    pub buffer_size: Option<usize>,
}
57
58impl ProviderConfig {
59 pub fn new(kind: ProviderKind, query: Query) -> Self {
60 Self {
61 kind,
62 query,
63 url: None,
64 bearer_token: None,
65 max_num_retries: None,
66 retry_backoff_ms: None,
67 retry_base_ms: None,
68 retry_ceiling_ms: None,
69 req_timeout_millis: None,
70 stop_on_head: false,
71 head_poll_interval_millis: None,
72 buffer_size: None,
73 }
74 }
75}
76
/// Selects which backend data provider [`start_stream`] connects to.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare kinds and use
/// them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    /// Dispatched to `provider::sqd`.
    Sqd,
    /// Dispatched to `provider::hypersync`.
    Hypersync,
    /// Dispatched to `provider::yellowstone_grpc`.
    YellowstoneGrpc,
}
83
#[cfg(feature = "pyo3")]
impl<'py> pyo3::FromPyObject<'py> for ProviderKind {
    /// Parses a [`ProviderKind`] from a Python string.
    ///
    /// Accepted values are `"sqd"`, `"hypersync"` and `"yellowstone_grpc"`.
    /// Any other string, or a non-string object, is reported as an error.
    fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
        use pyo3::types::PyAnyMethods;

        let name: &str = ob.extract().context("read as string")?;

        let kind = match name {
            "sqd" => Self::Sqd,
            "hypersync" => Self::Hypersync,
            "yellowstone_grpc" => Self::YellowstoneGrpc,
            other => return Err(anyhow!("unknown provider kind: {}", other).into()),
        };

        Ok(kind)
    }
}
99
/// Boxed, pinned, `Send + Sync` stream returned by [`start_stream`].
///
/// Each item is either an error or a map from `String` keys to Arrow
/// [`RecordBatch`]es, as produced by `cherry_query::run_query` (keys presumably
/// table names — confirm against `cherry_query`).
type DataStream = Pin<Box<dyn Stream<Item = Result<BTreeMap<String, RecordBatch>>> + Send + Sync>>;
101
102fn make_req_fields<T: DeserializeOwned>(query: &cherry_query::Query) -> Result<T> {
103 let mut req_fields_query = query.clone();
104 req_fields_query
105 .add_request_and_include_fields()
106 .context("add req and include fields")?;
107
108 let fields = req_fields_query
109 .fields
110 .into_iter()
111 .map(|(k, v)| {
112 (
113 k.strip_suffix('s').unwrap().to_owned(),
114 v.into_iter()
115 .map(|v| (v, true))
116 .collect::<BTreeMap<String, bool>>(),
117 )
118 })
119 .collect::<BTreeMap<String, _>>();
120
121 Ok(serde_json::from_value(serde_json::to_value(&fields).unwrap()).unwrap())
122}
123
124pub async fn start_stream(mut provider_config: ProviderConfig) -> Result<DataStream> {
125 let generic_query = match &mut provider_config.query {
126 Query::Evm(query) => {
127 let generic_query = evm_query_to_generic(query);
128
129 query.fields = make_req_fields(&generic_query).context("make req fields")?;
130
131 generic_query
132 }
133 Query::Svm(query) => {
134 let generic_query = svm_query_to_generic(query);
135
136 query.fields = make_req_fields(&generic_query).context("make req fields")?;
137
138 generic_query
139 }
140 };
141 let generic_query = Arc::new(generic_query);
142
143 let stream = match provider_config.kind {
144 ProviderKind::Sqd => {
145 provider::sqd::start_stream(provider_config).context("start sqd stream")?
146 }
147 ProviderKind::Hypersync => provider::hypersync::start_stream(provider_config)
148 .await
149 .context("start hypersync stream")?,
150 ProviderKind::YellowstoneGrpc => provider::yellowstone_grpc::start_stream(provider_config)
151 .await
152 .context("start yellowstone_grpc stream")?,
153 };
154
155 let stream = stream.then(move |res| {
156 let generic_query = Arc::clone(&generic_query);
157 async {
158 rayon_async::spawn(move || {
159 res.and_then(move |data| {
160 let data = cherry_query::run_query(&data, &generic_query)
161 .context("run local query")?;
162 Ok(data)
163 })
164 })
165 .await
166 .unwrap()
167 }
168 });
169
170 Ok(Box::pin(stream))
171}