Skip to main content

tycho_simulation/rfq/
stream.rs

1use std::collections::HashMap;
2
3use futures::{stream::select_all, StreamExt};
4use tycho_client::feed::{synchronizer::ComponentWithState, FeedMessage};
5use tycho_common::{
6    models::token::Token,
7    simulation::{errors::SimulationError, protocol_sim::ProtocolSim},
8    Bytes,
9};
10
11use crate::{
12    evm::decoder::TychoStreamDecoder,
13    protocol::{
14        errors::InvalidSnapshotError,
15        models::{TryFromWithBlock, Update},
16    },
17    rfq::{client::RFQClient, models::TimestampHeader},
18};
19
20/// `RFQStreamBuilder` is a utility for constructing and managing a merged stream of RFQ (Request
21/// For Quote) providers in Tycho.
22///
23/// It allows you to:
24/// - Register multiple `RFQClient` implementations, each providing its own stream of RFQ price
25///   updates.
26/// - Dynamically decode incoming updates into `Update` objects using `TychoStreamDecoder`.
27///
28/// The `build` method consumes the builder and runs the event loop, sending decoded `Update`s
29/// through the provided `mpsc::Sender`. It returns an error if decoding an update or forwarding
30/// it to the channel fails.
31///
32/// ### Error Handling:
33/// - Each `RFQClient`'s stream is expected to yield `Result<(String, StateSyncMessage), RFQError>`.
34/// - If a client's stream returns an `Err` (e.g., `RFQError::FatalError`), the client is
35///   **removed** from the merged stream, and the system continues running without it.
36#[derive(Default)]
37pub struct RFQStreamBuilder {
38    clients: Vec<Box<dyn RFQClient>>,
39    decoder: TychoStreamDecoder<TimestampHeader>,
40}
41
42impl RFQStreamBuilder {
43    pub fn new() -> Self {
44        Self { clients: Vec::new(), decoder: TychoStreamDecoder::new() }
45    }
46
47    pub fn add_client<T>(mut self, name: &str, provider: Box<dyn RFQClient>) -> Self
48    where
49        T: ProtocolSim
50            + TryFromWithBlock<ComponentWithState, TimestampHeader, Error = InvalidSnapshotError>
51            + Send
52            + 'static,
53    {
54        self.clients.push(provider);
55        self.decoder.register_decoder::<T>(name);
56        self
57    }
58
59    pub async fn build(self, tx: tokio::sync::mpsc::Sender<Update>) -> Result<(), SimulationError> {
60        let streams: Vec<_> = self
61            .clients
62            .into_iter()
63            .map(|provider| provider.stream())
64            .collect();
65
66        let mut merged = select_all(streams);
67
68        while let Some(next) = merged.next().await {
69            match next {
70                Ok((provider, msg)) => {
71                    let update = self
72                        .decoder
73                        .decode(&FeedMessage {
74                            state_msgs: HashMap::from([(provider.clone(), msg)]),
75                            sync_states: HashMap::new(),
76                        })
77                        .await
78                        .map_err(|e| {
79                            SimulationError::RecoverableError(format!("Decoding error: {e}"))
80                        })?;
81                    tx.send(update).await.map_err(|e| {
82                        SimulationError::RecoverableError(format!(
83                            "Failed to send update through channel: {e}"
84                        ))
85                    })?;
86                }
87                Err(e) => {
88                    tracing::error!(
89                        "RFQ stream fatal error: {e}. Assuming this stream will not emit more messages."
90                    );
91                }
92            }
93        }
94
95        Ok(())
96    }
97
98    /// Provides token metadata used to decode startup snapshots and initialize protocol states.
99    ///
100    /// This is not an ongoing stream filter. Components arriving after startup include their
101    /// own token metadata for decoding.
102    pub async fn set_tokens(self, tokens: HashMap<Bytes, Token>) -> Self {
103        self.decoder.set_tokens(tokens).await;
104        self
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::{any::Any, time::Duration};

    use async_trait::async_trait;
    use futures::stream::BoxStream;
    use num_bigint::BigUint;
    use serde::{Deserialize, Serialize};
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::IntervalStream;
    use tycho_client::feed::synchronizer::{Snapshot, StateSyncMessage};
    use tycho_common::{
        dto::ProtocolStateDelta,
        models::{
            protocol::{GetAmountOutParams, ProtocolComponent, ProtocolComponentState},
            token::Token,
        },
        simulation::{
            errors::{SimulationError, TransitionError},
            indicatively_priced::SignedQuote,
            protocol_sim::{Balances, GetAmountOutResult},
        },
        Bytes,
    };

    use super::*;
    use crate::{protocol::models::DecoderContext, rfq::errors::RFQError};

    /// Placeholder protocol state; only its decoding path is exercised by these tests, so every
    /// pricing method is left unimplemented.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    pub struct DummyProtocol;

    #[typetag::serde]
    impl ProtocolSim for DummyProtocol {
        fn fee(&self) -> f64 {
            unimplemented!("Not needed for this test")
        }

        fn spot_price(&self, _base: &Token, _quote: &Token) -> Result<f64, SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn get_amount_out(
            &self,
            _amount_in: BigUint,
            _token_in: &Token,
            _token_out: &Token,
        ) -> Result<GetAmountOutResult, SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn get_limits(
            &self,
            _sell_token: Bytes,
            _buy_token: Bytes,
        ) -> Result<(BigUint, BigUint), SimulationError> {
            unimplemented!("Not needed for this test")
        }

        fn delta_transition(
            &mut self,
            _delta: ProtocolStateDelta,
            _tokens: &HashMap<Bytes, Token>,
            _balances: &Balances,
        ) -> Result<(), TransitionError> {
            unimplemented!("Not needed for this test")
        }

        fn clone_box(&self) -> Box<dyn ProtocolSim> {
            Box::new(self.clone())
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }

        fn eq(&self, _other: &dyn ProtocolSim) -> bool {
            unimplemented!("Not needed for this test")
        }
    }

    /// Decoding always succeeds and ignores the incoming component, header and context.
    impl TryFromWithBlock<ComponentWithState, TimestampHeader> for DummyProtocol {
        type Error = InvalidSnapshotError;
        async fn try_from_with_header(
            _value: ComponentWithState,
            _header: TimestampHeader,
            _account_balances: &HashMap<Bytes, HashMap<Bytes, Bytes>>,
            _all_tokens: &HashMap<Bytes, Token>,
            _decoder_context: &DecoderContext,
        ) -> Result<Self, Self::Error> {
            Ok(DummyProtocol)
        }
    }

    /// Mock RFQ client that emits one snapshot per `interval` tick.
    ///
    /// The snapshot timestamp starts at 0 and advances by `interval` (in milliseconds) after each
    /// successful message. When the timestamp equals `error_at_time`, the stream yields an
    /// `RFQError::FatalError` instead; the timestamp is not advanced on that path, so every later
    /// tick yields the same error.
    pub struct MockRFQClient {
        name: String,
        interval: Duration,
        error_at_time: Option<u128>,
    }

    impl MockRFQClient {
        pub fn new(name: &str, interval: Duration, error_at_time: Option<u128>) -> Self {
            Self { name: name.to_string(), interval, error_at_time }
        }
    }

    #[async_trait]
    impl RFQClient for MockRFQClient {
        fn stream(
            &self,
        ) -> BoxStream<'static, Result<(String, StateSyncMessage<TimestampHeader>), RFQError>>
        {
            let name = self.name.clone();
            let error_at_time = self.error_at_time;
            let step = self.interval;
            let mut current_time: u128 = 0;

            let ticks = IntervalStream::new(tokio::time::interval(step)).map(move |_| {
                if error_at_time == Some(current_time) {
                    return Err(RFQError::FatalError(format!(
                        "{name} stream is dying and can't go on"
                    )));
                }

                let protocol_component =
                    ProtocolComponent { protocol_system: name.clone(), ..Default::default() };

                let snapshot = Snapshot {
                    states: HashMap::from([(
                        name.clone(),
                        ComponentWithState {
                            state: ProtocolComponentState {
                                component_id: name.clone(),
                                attributes: HashMap::new(),
                                balances: HashMap::new(),
                            },
                            component: protocol_component,
                            component_tvl: None,
                            entrypoints: vec![],
                        },
                    )]),
                    vm_storage: HashMap::new(),
                };

                let msg = StateSyncMessage {
                    header: TimestampHeader { timestamp: current_time as u64 },
                    snapshots: snapshot,
                    ..Default::default()
                };

                current_time += step.as_millis();
                Ok((name.clone(), msg))
            });
            Box::pin(ticks)
        }

        async fn request_binding_quote(
            &self,
            _params: &GetAmountOutParams,
        ) -> Result<SignedQuote, RFQError> {
            unimplemented!("Not needed for this test")
        }
    }

    #[tokio::test]
    async fn test_rfq_stream_builder() {
        // This test has two mocked RFQ clients:
        // 1. Bebop client that emits a message every 100ms and fails at t=300ms
        // 2. Hashflow client that emits a message every 200ms and never fails
        let (tx, mut rx) = mpsc::channel::<Update>(10);

        let builder = RFQStreamBuilder::new()
            .add_client::<DummyProtocol>(
                "bebop",
                Box::new(MockRFQClient::new("bebop", Duration::from_millis(100), Some(300))),
            )
            .add_client::<DummyProtocol>(
                "hashflow",
                Box::new(MockRFQClient::new("hashflow", Duration::from_millis(200), None)),
            );

        tokio::spawn(builder.build(tx));

        // Collect the first 6 updates: bebop at t=0, 100, 200 and hashflow at t=0, 200, 400.
        let mut updates = Vec::new();
        for _ in 0..6 {
            updates.push(rx.recv().await.unwrap());
        }

        // Split the updates per provider
        let bebop_updates: Vec<_> = updates
            .iter()
            .filter(|u| u.new_pairs.contains_key("bebop"))
            .collect();
        let hashflow_updates: Vec<_> = updates
            .iter()
            .filter(|u| u.new_pairs.contains_key("hashflow"))
            .collect();

        assert_eq!(bebop_updates[0].block_number_or_timestamp, 0);
        assert_eq!(hashflow_updates[0].block_number_or_timestamp, 0);
        assert_eq!(bebop_updates[1].block_number_or_timestamp, 100);
        assert_eq!(bebop_updates[2].block_number_or_timestamp, 200);
        assert_eq!(hashflow_updates[1].block_number_or_timestamp, 200);
        // The bebop stream fails at t=300, so no further bebop updates arrive; only hashflow
        // keeps producing.
        assert_eq!(bebop_updates.len(), 3);
        assert_eq!(hashflow_updates[2].block_number_or_timestamp, 400);
    }
}