1use std::collections::HashMap;
2
3use futures::{stream::select_all, StreamExt};
4use tycho_client::feed::{synchronizer::ComponentWithState, FeedMessage};
5use tycho_common::{
6 models::token::Token,
7 simulation::{errors::SimulationError, protocol_sim::ProtocolSim},
8 Bytes,
9};
10
11use crate::{
12 evm::decoder::TychoStreamDecoder,
13 protocol::{
14 errors::InvalidSnapshotError,
15 models::{TryFromWithBlock, Update},
16 },
17 rfq::{client::RFQClient, models::TimestampHeader},
18};
19
20#[derive(Default)]
37pub struct RFQStreamBuilder {
38 clients: Vec<Box<dyn RFQClient>>,
39 decoder: TychoStreamDecoder<TimestampHeader>,
40}
41
42impl RFQStreamBuilder {
43 pub fn new() -> Self {
44 Self { clients: Vec::new(), decoder: TychoStreamDecoder::new() }
45 }
46
47 pub fn add_client<T>(mut self, name: &str, provider: Box<dyn RFQClient>) -> Self
48 where
49 T: ProtocolSim
50 + TryFromWithBlock<ComponentWithState, TimestampHeader, Error = InvalidSnapshotError>
51 + Send
52 + 'static,
53 {
54 self.clients.push(provider);
55 self.decoder.register_decoder::<T>(name);
56 self
57 }
58
59 pub async fn build(self, tx: tokio::sync::mpsc::Sender<Update>) -> Result<(), SimulationError> {
60 let streams: Vec<_> = self
61 .clients
62 .into_iter()
63 .map(|provider| provider.stream())
64 .collect();
65
66 let mut merged = select_all(streams);
67
68 while let Some(next) = merged.next().await {
69 match next {
70 Ok((provider, msg)) => {
71 let update = self
72 .decoder
73 .decode(&FeedMessage {
74 state_msgs: HashMap::from([(provider.clone(), msg)]),
75 sync_states: HashMap::new(),
76 })
77 .await
78 .map_err(|e| {
79 SimulationError::RecoverableError(format!("Decoding error: {e}"))
80 })?;
81 tx.send(update).await.map_err(|e| {
82 SimulationError::RecoverableError(format!(
83 "Failed to send update through channel: {e}"
84 ))
85 })?;
86 }
87 Err(e) => {
88 tracing::error!(
89 "RFQ stream fatal error: {e}. Assuming this stream will not emit more messages."
90 );
91 }
92 }
93 }
94
95 Ok(())
96 }
97
98 pub async fn set_tokens(self, tokens: HashMap<Bytes, Token>) -> Self {
103 self.decoder.set_tokens(tokens).await;
104 self
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::{any::Any, time::Duration};

    use async_trait::async_trait;
    use futures::stream::BoxStream;
    use num_bigint::BigUint;
    use serde::{Deserialize, Serialize};
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::IntervalStream;
    use tycho_client::feed::synchronizer::{Snapshot, StateSyncMessage};
    use tycho_common::{
        dto::ProtocolStateDelta,
        models::{
            protocol::{GetAmountOutParams, ProtocolComponent, ProtocolComponentState},
            token::Token,
        },
        simulation::{
            errors::{SimulationError, TransitionError},
            indicatively_priced::SignedQuote,
            protocol_sim::{Balances, GetAmountOutResult},
        },
        Bytes,
    };

    use super::*;
    use crate::{protocol::models::DecoderContext, rfq::errors::RFQError};

    /// Minimal protocol state used only so a decoder can be registered for each mock client.
    /// Apart from the boxing/downcasting helpers, its methods are unimplemented.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    pub struct DummyProtocol;

    #[typetag::serde]
    impl ProtocolSim for DummyProtocol {
        fn fee(&self) -> f64 {
            unimplemented!("Not needed for this test")
        }

        fn spot_price(&self, _base: &Token, _quote: &Token) -> Result<f64, SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn get_amount_out(
            &self,
            _amount_in: BigUint,
            _token_in: &Token,
            _token_out: &Token,
        ) -> Result<GetAmountOutResult, SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn get_limits(
            &self,
            _sell_token: Bytes,
            _buy_token: Bytes,
        ) -> Result<(BigUint, BigUint), SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn delta_transition(
            &mut self,
            _delta: ProtocolStateDelta,
            _tokens: &HashMap<Bytes, Token>,
            _balances: &Balances,
        ) -> Result<(), TransitionError> {
            unimplemented!("Not needed for this test")
        }

        fn clone_box(&self) -> Box<dyn ProtocolSim> {
            Box::new(self.clone())
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }

        fn eq(&self, _other: &dyn ProtocolSim) -> bool {
            unimplemented!("Not needed for this test")
        }
    }

    /// Decoding always succeeds and ignores the snapshot contents.
    impl TryFromWithBlock<ComponentWithState, TimestampHeader> for DummyProtocol {
        type Error = InvalidSnapshotError;
        async fn try_from_with_header(
            _value: ComponentWithState,
            _header: TimestampHeader,
            _account_balances: &HashMap<Bytes, HashMap<Bytes, Bytes>>,
            _all_tokens: &HashMap<Bytes, Token>,
            _decoder_context: &DecoderContext,
        ) -> Result<Self, Self::Error> {
            Ok(DummyProtocol)
        }
    }

    /// Mock RFQ client that emits one snapshot per real `interval` tick.
    ///
    /// Each snapshot's header timestamp comes from a simulated clock, in milliseconds, which
    /// starts at 0 and advances by `interval` after every successful message. When the clock
    /// equals `error_at_time`, the stream yields a `FatalError` on that tick and on every later
    /// tick, because the clock is no longer advanced.
    pub struct MockRFQClient {
        name: String,
        interval: Duration,
        error_at_time: Option<u128>,
    }

    impl MockRFQClient {
        pub fn new(name: &str, interval: Duration, error_at_time: Option<u128>) -> Self {
            Self { name: name.to_string(), interval, error_at_time }
        }
    }

    #[async_trait]
    impl RFQClient for MockRFQClient {
        fn stream(
            &self,
        ) -> BoxStream<'static, Result<(String, StateSyncMessage<TimestampHeader>), RFQError>>
        {
            let name = self.name.clone();
            let error_at_time = self.error_at_time;
            let tick = self.interval;
            // Simulated clock in milliseconds, independent of the real tokio clock.
            let mut current_time: u128 = 0;

            let ticks = IntervalStream::new(tokio::time::interval(tick)).map(move |_| {
                // The clock is not advanced on this path, so once it reaches `error_at_time`
                // every subsequent tick fails as well.
                if error_at_time == Some(current_time) {
                    return Err(RFQError::FatalError(format!(
                        "{name} stream is dying and can't go on"
                    )));
                }

                let protocol_component =
                    ProtocolComponent { protocol_system: name.clone(), ..Default::default() };

                let snapshot = Snapshot {
                    states: HashMap::from([(
                        name.clone(),
                        ComponentWithState {
                            state: ProtocolComponentState {
                                component_id: name.clone(),
                                attributes: HashMap::new(),
                                balances: HashMap::new(),
                            },
                            component: protocol_component,
                            component_tvl: None,
                            entrypoints: vec![],
                        },
                    )]),
                    vm_storage: HashMap::new(),
                };

                let msg = StateSyncMessage {
                    header: TimestampHeader {
                        timestamp: u64::try_from(current_time)
                            .expect("mock timestamp fits in u64"),
                    },
                    snapshots: snapshot,
                    ..Default::default()
                };

                current_time += tick.as_millis();
                Ok((name.clone(), msg))
            });
            Box::pin(ticks)
        }

        async fn request_binding_quote(
            &self,
            _params: &GetAmountOutParams,
        ) -> Result<SignedQuote, RFQError> {
            unimplemented!("Not needed for this test")
        }
    }

    #[tokio::test]
    async fn test_rfq_stream_builder() {
        let (tx, mut rx) = mpsc::channel::<Update>(10);

        // bebop emits every 100ms and fails permanently once its clock reaches 300ms;
        // hashflow emits every 200ms and never fails.
        let builder = RFQStreamBuilder::new()
            .add_client::<DummyProtocol>(
                "bebop",
                Box::new(MockRFQClient::new("bebop", Duration::from_millis(100), Some(300))),
            )
            .add_client::<DummyProtocol>(
                "hashflow",
                Box::new(MockRFQClient::new("hashflow", Duration::from_millis(200), None)),
            );

        tokio::spawn(builder.build(tx));

        // bebop can only ever produce 3 successful updates (0, 100, 200), so the first 6
        // updates must be 3 from each provider, however they interleave.
        let mut updates = Vec::with_capacity(6);
        for _ in 0..6 {
            updates.push(rx.recv().await.unwrap());
        }

        // Timestamps of the updates containing `provider`'s component, in arrival order.
        let timestamps_for = |provider: &str| -> Vec<_> {
            updates
                .iter()
                .filter(|u| u.new_pairs.contains_key(provider))
                .map(|u| u.block_number_or_timestamp)
                .collect()
        };

        // Whole-vector comparisons check count and order together and report a readable diff
        // instead of panicking on an out-of-bounds index.
        assert_eq!(timestamps_for("bebop"), vec![0, 100, 200]);
        assert_eq!(timestamps_for("hashflow"), vec![0, 200, 400]);
    }
}