1use std::collections::HashMap;
6use std::str::FromStr;
7
8use rust_decimal::Decimal;
9
10use crate::websocket::types::{Balance, BalanceEntry, Order, UserEventData};
11
12fn is_zero(s: &str) -> bool {
14 Decimal::from_str(s)
15 .map(|v| v.is_zero())
16 .unwrap_or(false)
17}
18
/// In-memory view of one user's orders and balances, maintained by applying
/// `snapshot`, `order_update` and `balance_update` events.
#[derive(Debug, Clone, Default)]
pub struct UserState {
    /// Identifier of the user this state belongs to, as passed to `UserState::new`.
    pub user: String,
    /// Tracked orders, keyed by order hash.
    pub orders: HashMap<String, Order>,
    /// Balance entries. Entries written by update events are keyed by
    /// `"{market_pubkey}:{deposit_mint}"`; snapshot entries keep the key they
    /// arrived with (presumably the same format — confirm against the server).
    pub balances: HashMap<String, BalanceEntry>,
    /// Set once a snapshot has been applied; reset by `clear`.
    has_snapshot: bool,
    /// Timestamp carried by the most recently applied event, if it had one.
    last_timestamp: Option<String>,
}
33
34impl UserState {
35 pub fn new(user: String) -> Self {
37 Self {
38 user,
39 orders: HashMap::new(),
40 balances: HashMap::new(),
41 has_snapshot: false,
42 last_timestamp: None,
43 }
44 }
45
46 pub fn apply_snapshot(&mut self, data: &UserEventData) {
48 self.orders.clear();
50 self.balances.clear();
51
52 for order in &data.orders {
54 self.orders.insert(order.order_hash.clone(), order.clone());
55 }
56
57 for (key, balance) in &data.balances {
59 self.balances.insert(key.clone(), balance.clone());
60 }
61
62 self.has_snapshot = true;
63 self.last_timestamp = data.timestamp.clone();
64 }
65
66 pub fn apply_order_update(&mut self, data: &UserEventData) {
68 if let Some(update) = &data.order {
69 let order_hash = &update.order_hash;
70
71 if is_zero(&update.remaining) {
73 self.orders.remove(order_hash);
74 } else if let Some(existing) = self.orders.get_mut(order_hash) {
75 existing.remaining = update.remaining.clone();
77 existing.filled = update.filled.clone();
78 } else {
79 if let (Some(market_pubkey), Some(orderbook_id)) =
83 (&data.market_pubkey, &data.orderbook_id)
84 {
85 let order = Order {
86 order_hash: order_hash.clone(),
87 market_pubkey: market_pubkey.clone(),
88 orderbook_id: orderbook_id.clone(),
89 side: update.side,
90 maker_amount: update.remaining.clone(), taker_amount: "0".to_string(), remaining: update.remaining.clone(),
93 filled: update.filled.clone(),
94 price: update.price.clone(),
95 created_at: update.created_at,
96 expiration: 0,
97 };
98 self.orders.insert(order_hash.clone(), order);
99 }
100 }
101
102 if let Some(balance) = &update.balance {
104 self.apply_balance_from_order(data, balance);
105 }
106 }
107
108 self.last_timestamp = data.timestamp.clone();
109 }
110
111 pub fn apply_balance_update(&mut self, data: &UserEventData) {
113 if let (Some(market_pubkey), Some(deposit_mint), Some(balance)) =
114 (&data.market_pubkey, &data.deposit_mint, &data.balance)
115 {
116 let key = format!("{}:{}", market_pubkey, deposit_mint);
117 let entry = BalanceEntry {
118 market_pubkey: market_pubkey.clone(),
119 deposit_mint: deposit_mint.clone(),
120 outcomes: balance.outcomes.clone(),
121 };
122 self.balances.insert(key, entry);
123 }
124
125 self.last_timestamp = data.timestamp.clone();
126 }
127
128 fn apply_balance_from_order(&mut self, data: &UserEventData, balance: &Balance) {
130 if let (Some(market_pubkey), Some(deposit_mint)) =
131 (&data.market_pubkey, &data.deposit_mint)
132 {
133 let key = format!("{}:{}", market_pubkey, deposit_mint);
134 let entry = BalanceEntry {
135 market_pubkey: market_pubkey.clone(),
136 deposit_mint: deposit_mint.clone(),
137 outcomes: balance.outcomes.clone(),
138 };
139 self.balances.insert(key, entry);
140 } else if let Some(market_pubkey) = &data.market_pubkey {
141 for (key, entry) in self.balances.iter_mut() {
143 if key.starts_with(market_pubkey) {
144 entry.outcomes = balance.outcomes.clone();
145 break;
146 }
147 }
148 }
149 }
150
151 pub fn apply_event(&mut self, data: &UserEventData) {
153 match data.event_type.as_str() {
154 "snapshot" => self.apply_snapshot(data),
155 "order_update" => self.apply_order_update(data),
156 "balance_update" => self.apply_balance_update(data),
157 _ => {
158 tracing::warn!("Unknown user event type: {}", data.event_type);
159 }
160 }
161 }
162
163 pub fn get_order(&self, order_hash: &str) -> Option<&Order> {
165 self.orders.get(order_hash)
166 }
167
168 pub fn open_orders(&self) -> Vec<&Order> {
170 self.orders.values().collect()
171 }
172
173 pub fn orders_for_market(&self, market_pubkey: &str) -> Vec<&Order> {
175 self.orders
176 .values()
177 .filter(|o| o.market_pubkey == market_pubkey)
178 .collect()
179 }
180
181 pub fn orders_for_orderbook(&self, orderbook_id: &str) -> Vec<&Order> {
183 self.orders
184 .values()
185 .filter(|o| o.orderbook_id == orderbook_id)
186 .collect()
187 }
188
189 pub fn get_balance(&self, market_pubkey: &str, deposit_mint: &str) -> Option<&BalanceEntry> {
191 let key = format!("{}:{}", market_pubkey, deposit_mint);
192 self.balances.get(&key)
193 }
194
195 pub fn all_balances(&self) -> Vec<&BalanceEntry> {
197 self.balances.values().collect()
198 }
199
200 pub fn idle_balance_for_outcome(
202 &self,
203 market_pubkey: &str,
204 deposit_mint: &str,
205 outcome_index: i32,
206 ) -> Option<String> {
207 self.get_balance(market_pubkey, deposit_mint)
208 .and_then(|b| b.outcomes.iter().find(|o| o.outcome_index == outcome_index))
209 .map(|o| o.idle.clone())
210 }
211
212 pub fn on_book_balance_for_outcome(
214 &self,
215 market_pubkey: &str,
216 deposit_mint: &str,
217 outcome_index: i32,
218 ) -> Option<String> {
219 self.get_balance(market_pubkey, deposit_mint)
220 .and_then(|b| b.outcomes.iter().find(|o| o.outcome_index == outcome_index))
221 .map(|o| o.on_book.clone())
222 }
223
224 pub fn order_count(&self) -> usize {
226 self.orders.len()
227 }
228
229 pub fn has_snapshot(&self) -> bool {
231 self.has_snapshot
232 }
233
234 pub fn last_timestamp(&self) -> Option<&str> {
236 self.last_timestamp.as_deref()
237 }
238
239 pub fn clear(&mut self) {
241 self.orders.clear();
242 self.balances.clear();
243 self.has_snapshot = false;
244 self.last_timestamp = None;
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::websocket::types::{OrderUpdate, OutcomeBalance};

    /// Timestamp attached to every update event built by these tests.
    const UPDATE_TIMESTAMP: &str = "2024-01-01T00:00:01.000Z";

    /// Builds a one-element outcome list for outcome index 0.
    fn single_outcome(idle: &str, on_book: &str) -> Vec<OutcomeBalance> {
        vec![OutcomeBalance {
            outcome_index: 0,
            mint: "outcome_mint".to_string(),
            idle: idle.to_string(),
            on_book: on_book.to_string(),
        }]
    }

    /// Snapshot fixture: one partially filled order (`hash1`) and one balance
    /// entry for `market1` / `mint1`.
    fn create_snapshot() -> UserEventData {
        let resting_order = Order {
            order_hash: "hash1".to_string(),
            market_pubkey: "market1".to_string(),
            orderbook_id: "ob1".to_string(),
            side: 0,
            maker_amount: "0.001000".to_string(),
            taker_amount: "0.000500".to_string(),
            remaining: "0.000800".to_string(),
            filled: "0.000200".to_string(),
            price: "0.500000".to_string(),
            created_at: 1704067200000,
            expiration: 0,
        };
        let holdings = BalanceEntry {
            market_pubkey: "market1".to_string(),
            deposit_mint: "mint1".to_string(),
            outcomes: single_outcome("0.005000", "0.001000"),
        };

        UserEventData {
            event_type: "snapshot".to_string(),
            orders: vec![resting_order],
            balances: std::iter::once(("market1:mint1".to_string(), holdings)).collect(),
            order: None,
            balance: None,
            market_pubkey: None,
            orderbook_id: None,
            deposit_mint: None,
            timestamp: Some("2024-01-01T00:00:00.000Z".to_string()),
        }
    }

    /// A fresh state for `user1` with the snapshot fixture already applied.
    fn state_with_snapshot() -> UserState {
        let mut state = UserState::new("user1".to_string());
        state.apply_snapshot(&create_snapshot());
        state
    }

    /// Builds an `order_update` event for `hash1` on `market1` / `ob1` with the
    /// given amounts and no attached balance.
    fn hash1_order_update(fill_amount: &str, remaining: &str, filled: &str) -> UserEventData {
        let payload = OrderUpdate {
            order_hash: "hash1".to_string(),
            price: "0.500000".to_string(),
            fill_amount: fill_amount.to_string(),
            remaining: remaining.to_string(),
            filled: filled.to_string(),
            side: 0,
            is_maker: true,
            created_at: 1704067200000,
            balance: None,
        };

        UserEventData {
            event_type: "order_update".to_string(),
            orders: vec![],
            balances: HashMap::new(),
            order: Some(payload),
            balance: None,
            market_pubkey: Some("market1".to_string()),
            orderbook_id: Some("ob1".to_string()),
            deposit_mint: None,
            timestamp: Some(UPDATE_TIMESTAMP.to_string()),
        }
    }

    #[test]
    fn test_apply_snapshot() {
        let state = state_with_snapshot();

        assert!(state.has_snapshot());
        assert_eq!(state.order_count(), 1);
        assert!(state.get_order("hash1").is_some());
        assert!(state.get_balance("market1", "mint1").is_some());
    }

    #[test]
    fn test_order_update() {
        let mut state = state_with_snapshot();
        let partial_fill = hash1_order_update("0.000100", "0.000700", "0.000300");

        state.apply_order_update(&partial_fill);

        let order = state.get_order("hash1").unwrap();
        assert_eq!(order.remaining, "0.000700");
        assert_eq!(order.filled, "0.000300");
    }

    #[test]
    fn test_order_removal_on_full_fill() {
        let mut state = state_with_snapshot();
        let full_fill = hash1_order_update("0.000800", "0", "0.001000");

        state.apply_order_update(&full_fill);

        assert!(state.get_order("hash1").is_none());
        assert_eq!(state.order_count(), 0);
    }

    #[test]
    fn test_balance_update() {
        let mut state = state_with_snapshot();
        let new_balance = UserEventData {
            event_type: "balance_update".to_string(),
            orders: vec![],
            balances: HashMap::new(),
            order: None,
            balance: Some(Balance {
                outcomes: single_outcome("0.006000", "0.000500"),
            }),
            market_pubkey: Some("market1".to_string()),
            orderbook_id: None,
            deposit_mint: Some("mint1".to_string()),
            timestamp: Some(UPDATE_TIMESTAMP.to_string()),
        };

        state.apply_balance_update(&new_balance);

        let balance = state.get_balance("market1", "mint1").unwrap();
        assert_eq!(balance.outcomes[0].idle, "0.006000");
        assert_eq!(balance.outcomes[0].on_book, "0.000500");
    }
}