1use crate::types::{NetPosition, NettingConfig, NettingResult, PartySummary, Trade, TradeStatus};
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
10use std::collections::HashMap;
11
/// Batch kernel that nets clearing trades into per-party settlement positions.
///
/// The computations are associated functions ([`NettingCalculation::calculate`],
/// [`NettingCalculation::calculate_bilateral`], [`NettingCalculation::stats_by_security`])
/// that take no `self`; an instance only carries the kernel metadata.
#[derive(Debug, Clone)]
pub struct NettingCalculation {
    /// Kernel metadata, exposed through [`GpuKernel::metadata`].
    metadata: KernelMetadata,
}
23
24impl Default for NettingCalculation {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl NettingCalculation {
31 #[must_use]
33 pub fn new() -> Self {
34 Self {
35 metadata: KernelMetadata::batch("clearing/netting", Domain::Clearing)
36 .with_description("Multilateral netting calculation")
37 .with_throughput(10_000)
38 .with_latency_us(500.0),
39 }
40 }
41
42 pub fn calculate(trades: &[Trade], config: &NettingConfig) -> NettingResult {
44 let eligible_trades: Vec<_> = trades
46 .iter()
47 .filter(|t| {
48 if !config.include_failed {
49 matches!(
50 t.status,
51 TradeStatus::Pending | TradeStatus::Validated | TradeStatus::Matched
52 )
53 } else {
54 true
55 }
56 })
57 .collect();
58
59 let gross_trade_count = eligible_trades.len() as u64;
60
61 let get_key = |trade: &Trade, party: &str| -> String {
63 let mut key = party.to_string();
64 if config.net_by_security {
65 key.push_str(&format!(":{}", trade.security_id));
66 }
67 if config.net_by_settlement_date {
68 key.push_str(&format!(":{}", trade.settlement_date));
69 }
70 if config.net_by_currency {
71 let currency = trade
73 .attributes
74 .get("currency")
75 .map(|s| s.as_str())
76 .unwrap_or("USD");
77 key.push_str(&format!(":{}", currency));
78 }
79 key
80 };
81
82 let get_currency = |trade: &Trade| -> String {
84 trade
85 .attributes
86 .get("currency")
87 .cloned()
88 .unwrap_or_else(|| "USD".to_string())
89 };
90
91 let mut positions_map: HashMap<String, NetPositionBuilder> = HashMap::new();
93
94 for trade in &eligible_trades {
95 let currency = get_currency(trade);
96
97 let buyer_key = get_key(trade, &trade.buyer_id);
99 let buyer_pos = positions_map.entry(buyer_key).or_insert_with(|| {
100 NetPositionBuilder::new(
101 trade.buyer_id.clone(),
102 trade.security_id.clone(),
103 currency.clone(),
104 )
105 });
106 buyer_pos.add_receive(trade.quantity, trade.value(), trade.id);
107
108 let seller_key = get_key(trade, &trade.seller_id);
110 let seller_pos = positions_map.entry(seller_key).or_insert_with(|| {
111 NetPositionBuilder::new(
112 trade.seller_id.clone(),
113 trade.security_id.clone(),
114 currency,
115 )
116 });
117 seller_pos.add_deliver(trade.quantity, trade.value(), trade.id);
118 }
119
120 let positions: Vec<_> = positions_map
122 .into_values()
123 .map(|builder| builder.build())
124 .collect();
125
126 let net_instruction_count = positions.len() as u64;
127
128 let efficiency = if gross_trade_count > 0 {
130 1.0 - (net_instruction_count as f64 / (gross_trade_count * 2) as f64)
131 } else {
132 0.0
133 };
134
135 let mut party_summary: HashMap<String, PartySummary> = HashMap::new();
137
138 for pos in &positions {
139 let summary = party_summary.entry(pos.party_id.clone()).or_default();
140
141 if pos.net_quantity > 0 {
142 summary.gross_receipts += pos.net_quantity;
143 } else {
144 summary.gross_deliveries += pos.net_quantity.unsigned_abs() as i64;
145 }
146 summary.net_position += pos.net_quantity;
147
148 if pos.net_payment > 0 {
149 summary.gross_payments -= pos.net_payment; } else {
151 summary.gross_payments += pos.net_payment.unsigned_abs() as i64;
152 }
153 summary.net_payment += pos.net_payment;
154 summary.trade_count += pos.trade_ids.len() as u64;
155 }
156
157 NettingResult {
158 positions,
159 gross_trade_count,
160 net_instruction_count,
161 efficiency,
162 party_summary,
163 }
164 }
165
166 pub fn calculate_bilateral(
168 trades: &[Trade],
169 party_a: &str,
170 party_b: &str,
171 ) -> BilateralNetResult {
172 let relevant_trades: Vec<_> = trades
173 .iter()
174 .filter(|t| {
175 (t.buyer_id == party_a && t.seller_id == party_b)
176 || (t.buyer_id == party_b && t.seller_id == party_a)
177 })
178 .collect();
179
180 let mut a_receives = 0i64;
181 let mut a_delivers = 0i64;
182 let mut a_pays = 0i64;
183 let mut a_collects = 0i64;
184
185 for trade in &relevant_trades {
186 if trade.buyer_id == party_a {
187 a_receives += trade.quantity;
189 a_pays += trade.value();
190 } else {
191 a_delivers += trade.quantity;
193 a_collects += trade.value();
194 }
195 }
196
197 BilateralNetResult {
198 party_a: party_a.to_string(),
199 party_b: party_b.to_string(),
200 trade_count: relevant_trades.len() as u64,
201 net_securities_a: a_receives - a_delivers,
202 net_payment_a: a_collects - a_pays,
203 }
204 }
205
206 pub fn stats_by_security(result: &NettingResult) -> HashMap<String, SecurityNettingStats> {
208 let mut stats: HashMap<String, SecurityNettingStats> = HashMap::new();
209
210 for pos in &result.positions {
211 let stat =
212 stats
213 .entry(pos.security_id.clone())
214 .or_insert_with(|| SecurityNettingStats {
215 security_id: pos.security_id.clone(),
216 total_net_positions: 0,
217 total_trades: 0,
218 net_quantity: 0,
219 gross_volume: 0,
220 });
221
222 stat.total_net_positions += 1;
223 stat.total_trades += pos.trade_ids.len() as u64;
224 stat.net_quantity += pos.net_quantity.unsigned_abs() as i64;
225 stat.gross_volume += pos.trade_ids.len() as i64; }
227
228 stats
229 }
230}
231
232impl GpuKernel for NettingCalculation {
233 fn metadata(&self) -> &KernelMetadata {
234 &self.metadata
235 }
236}
237
238struct NetPositionBuilder {
240 party_id: String,
241 security_id: String,
242 currency: String,
243 receives: i64,
244 delivers: i64,
245 payments_in: i64,
246 payments_out: i64,
247 trade_ids: Vec<u64>,
248}
249
250impl NetPositionBuilder {
251 fn new(party_id: String, security_id: String, currency: String) -> Self {
252 Self {
253 party_id,
254 security_id,
255 currency,
256 receives: 0,
257 delivers: 0,
258 payments_in: 0,
259 payments_out: 0,
260 trade_ids: Vec::new(),
261 }
262 }
263
264 fn add_receive(&mut self, quantity: i64, payment: i64, trade_id: u64) {
265 self.receives += quantity;
266 self.payments_out += payment;
267 if !self.trade_ids.contains(&trade_id) {
268 self.trade_ids.push(trade_id);
269 }
270 }
271
272 fn add_deliver(&mut self, quantity: i64, payment: i64, trade_id: u64) {
273 self.delivers += quantity;
274 self.payments_in += payment;
275 if !self.trade_ids.contains(&trade_id) {
276 self.trade_ids.push(trade_id);
277 }
278 }
279
280 fn build(self) -> NetPosition {
281 NetPosition {
282 party_id: self.party_id,
283 security_id: self.security_id,
284 net_quantity: self.receives - self.delivers,
285 net_payment: self.payments_in - self.payments_out,
286 currency: self.currency,
287 trade_ids: self.trade_ids,
288 }
289 }
290}
291
/// Outcome of [`NettingCalculation::calculate_bilateral`], expressed from `party_a`'s side.
#[derive(Debug, Clone)]
pub struct BilateralNetResult {
    /// First counterparty; the net figures below are from this party's perspective.
    pub party_a: String,
    /// Second counterparty.
    pub party_b: String,
    /// Number of trades between the two parties, in either direction.
    pub trade_count: u64,
    /// Quantity `party_a` receives minus quantity it delivers (positive: net receiver).
    pub net_securities_a: i64,
    /// Cash `party_a` collects minus cash it pays (positive: net cash receiver).
    pub net_payment_a: i64,
}
306
/// Per-security aggregate produced by [`NettingCalculation::stats_by_security`].
#[derive(Debug, Clone)]
pub struct SecurityNettingStats {
    /// Security these figures cover.
    pub security_id: String,
    /// Number of net positions carrying this security id.
    pub total_net_positions: u64,
    /// Sum of the trade-id counts of those positions.
    pub total_trades: u64,
    /// Sum of the absolute net quantities of those positions.
    pub net_quantity: i64,
    /// NOTE(review): `stats_by_security` currently fills this with the trade-id count
    /// (same value as `total_trades`), not a quantity — confirm the intended meaning.
    pub gross_volume: i64,
}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Four trades: three AAPL trades among A, B and C, and one MSFT trade between A and B.
    fn create_test_trades() -> Vec<Trade> {
        vec![
            Trade::new(
                1,
                "AAPL".to_string(),
                "A".to_string(),
                "B".to_string(),
                100,
                150,
                1700000000,
                1700172800,
            ),
            Trade::new(
                2,
                "AAPL".to_string(),
                "B".to_string(),
                "A".to_string(),
                50,
                151,
                1700000100,
                1700172800,
            ),
            Trade::new(
                3,
                "AAPL".to_string(),
                "A".to_string(),
                "C".to_string(),
                30,
                149,
                1700000200,
                1700172800,
            ),
            Trade::new(
                4,
                "MSFT".to_string(),
                "A".to_string(),
                "B".to_string(),
                200,
                300,
                1700000300,
                1700172800,
            ),
        ]
    }

    #[test]
    fn test_netting_metadata() {
        let kernel = NettingCalculation::new();
        assert_eq!(kernel.metadata().id, "clearing/netting");
        assert_eq!(kernel.metadata().domain, Domain::Clearing);
    }

    #[test]
    fn test_basic_netting() {
        let trades = create_test_trades();
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        assert_eq!(result.gross_trade_count, 4);
        assert!(result.net_instruction_count < result.gross_trade_count * 2);
        assert!(result.efficiency > 0.0);
    }

    #[test]
    fn test_net_positions() {
        let trades = vec![
            Trade::new(
                1,
                "AAPL".to_string(),
                "A".to_string(),
                "B".to_string(),
                100,
                150,
                1700000000,
                1700172800,
            ),
            Trade::new(
                2,
                "AAPL".to_string(),
                "B".to_string(),
                "A".to_string(),
                100,
                150,
                1700000100,
                1700172800,
            ),
        ];
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        // The position must exist; a missing one is a failure, not a silent pass.
        let a_pos = result
            .positions
            .iter()
            .find(|p| p.party_id == "A" && p.security_id == "AAPL")
            .expect("party A should have an AAPL position");
        assert_eq!(a_pos.net_quantity, 0);
    }

    #[test]
    fn test_bilateral_netting() {
        let trades = create_test_trades();

        let result = NettingCalculation::calculate_bilateral(&trades, "A", "B");

        // Trades 1, 2 and 4 are between A and B; trade 3 involves C.
        assert_eq!(result.trade_count, 3);
        assert!(result.net_securities_a > 0);

        // Swapping the parties must mirror every net figure.
        let mirrored = NettingCalculation::calculate_bilateral(&trades, "B", "A");
        assert_eq!(mirrored.trade_count, result.trade_count);
        assert_eq!(mirrored.net_securities_a, -result.net_securities_a);
        assert_eq!(mirrored.net_payment_a, -result.net_payment_a);
    }

    #[test]
    fn test_netting_efficiency() {
        // Four AAPL trades between A and B collapse into one position per party.
        let trades = vec![
            Trade::new(
                1,
                "AAPL".to_string(),
                "A".to_string(),
                "B".to_string(),
                100,
                150,
                1700000000,
                1700172800,
            ),
            Trade::new(
                2,
                "AAPL".to_string(),
                "B".to_string(),
                "A".to_string(),
                100,
                150,
                1700000100,
                1700172800,
            ),
            Trade::new(
                3,
                "AAPL".to_string(),
                "A".to_string(),
                "B".to_string(),
                100,
                150,
                1700000200,
                1700172800,
            ),
            Trade::new(
                4,
                "AAPL".to_string(),
                "B".to_string(),
                "A".to_string(),
                100,
                150,
                1700000300,
                1700172800,
            ),
        ];
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        // 2 net instructions instead of 8 gross ones: 1 - 2/8.
        assert_eq!(result.net_instruction_count, 2);
        assert!((result.efficiency - 0.75).abs() < 1e-9);
    }

    #[test]
    fn test_party_summary() {
        let trades = create_test_trades();
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        assert!(result.party_summary.contains_key("A"));
        assert!(result.party_summary.contains_key("B"));
        assert!(result.party_summary.contains_key("C"));

        let a = &result.party_summary["A"];
        let b = &result.party_summary["B"];
        let c = &result.party_summary["C"];

        assert!(a.net_position > 0);
        assert!(b.net_position < 0);
        assert!(c.net_position < 0);

        // Every trade moves the same quantity and cash between its two sides, so the
        // party totals must balance to zero.
        let total_securities: i64 = result.party_summary.values().map(|s| s.net_position).sum();
        let total_cash: i64 = result.party_summary.values().map(|s| s.net_payment).sum();
        assert_eq!(total_securities, 0);
        assert_eq!(total_cash, 0);

        // A is in all four trades, B in three, C in one.
        assert_eq!(a.trade_count, 4);
        assert_eq!(b.trade_count, 3);
        assert_eq!(c.trade_count, 1);
    }

    #[test]
    fn test_net_by_security() {
        let trades = create_test_trades();
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        let aapl_positions: Vec<_> = result
            .positions
            .iter()
            .filter(|p| p.security_id == "AAPL")
            .collect();
        let msft_positions: Vec<_> = result
            .positions
            .iter()
            .filter(|p| p.security_id == "MSFT")
            .collect();

        assert!(!aapl_positions.is_empty());
        assert!(!msft_positions.is_empty());
    }

    #[test]
    fn test_exclude_failed_trades() {
        let mut trades = create_test_trades();
        trades[0].status = TradeStatus::Failed;

        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        assert_eq!(result.gross_trade_count, 3);
    }

    #[test]
    fn test_include_failed_trades() {
        let mut trades = create_test_trades();
        trades[0].status = TradeStatus::Failed;

        let config = NettingConfig {
            include_failed: true,
            ..NettingConfig::default()
        };

        let result = NettingCalculation::calculate(&trades, &config);

        assert_eq!(result.gross_trade_count, 4);
    }

    #[test]
    fn test_stats_by_security() {
        let trades = create_test_trades();
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);
        let stats = NettingCalculation::stats_by_security(&result);

        assert!(stats.contains_key("AAPL"));
        assert!(stats.contains_key("MSFT"));

        // AAPL: A [1, 2, 3], B [1, 2], C [3]; MSFT: A [4], B [4].
        assert_eq!(stats["AAPL"].total_net_positions, 3);
        assert_eq!(stats["AAPL"].total_trades, 6);
        assert_eq!(stats["MSFT"].total_net_positions, 2);
        assert_eq!(stats["MSFT"].total_trades, 2);
    }

    #[test]
    fn test_self_trade_counted_once() {
        let trades = vec![Trade::new(
            1,
            "AAPL".to_string(),
            "A".to_string(),
            "A".to_string(),
            100,
            150,
            1700000000,
            1700172800,
        )];
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        // Both legs land in the same position, which must list the trade only once.
        assert_eq!(result.gross_trade_count, 1);
        assert_eq!(result.positions.len(), 1);
        let pos = &result.positions[0];
        assert_eq!(pos.net_quantity, 0);
        assert_eq!(pos.net_payment, 0);
        assert_eq!(pos.trade_ids, vec![1]);
    }

    #[test]
    fn test_empty_trades() {
        let trades: Vec<Trade> = vec![];
        let config = NettingConfig::default();

        let result = NettingCalculation::calculate(&trades, &config);

        assert_eq!(result.gross_trade_count, 0);
        assert_eq!(result.net_instruction_count, 0);
        assert!((result.efficiency - 0.0).abs() < 0.001);
    }
}