1use rusqlite::{params, Connection};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct EquitySnapshot {
12 pub timestamp: String,
13 pub equity_usdc: f64,
14 pub realized_pnl: f64,
15 pub unrealized_pnl: f64,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct EquitySummary {
21 pub current_equity_usdc: f64,
22 pub peak_equity_usdc: f64,
23 pub current_drawdown_usdc: f64,
24 pub current_drawdown_pct: f64,
25 pub max_drawdown_usdc: f64,
26 pub max_drawdown_pct: f64,
27 pub snapshots: Vec<EquitySnapshot>,
28}
29
30pub struct EquityTracker {
32 db: Connection,
33}
34
35impl EquityTracker {
36 pub fn new(db_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
38 let db = Connection::open(db_path)?;
39 db.execute_batch(
40 "CREATE TABLE IF NOT EXISTS equity_snapshots (
41 id INTEGER PRIMARY KEY AUTOINCREMENT,
42 timestamp TEXT NOT NULL,
43 equity_usdc REAL NOT NULL,
44 realized_pnl REAL NOT NULL,
45 unrealized_pnl REAL NOT NULL
46 );
47 CREATE INDEX IF NOT EXISTS idx_equity_ts ON equity_snapshots(timestamp);",
48 )?;
49 Ok(Self { db })
50 }
51
52 pub fn record_snapshot(
54 &self,
55 snapshot: &EquitySnapshot,
56 ) -> Result<(), Box<dyn std::error::Error>> {
57 self.db.execute(
58 "INSERT INTO equity_snapshots (timestamp, equity_usdc, realized_pnl, unrealized_pnl)
59 VALUES (?1, ?2, ?3, ?4)",
60 params![
61 snapshot.timestamp,
62 snapshot.equity_usdc,
63 snapshot.realized_pnl,
64 snapshot.unrealized_pnl,
65 ],
66 )?;
67 Ok(())
68 }
69
70 pub fn get_snapshots(
72 &self,
73 days: u32,
74 ) -> Result<Vec<EquitySnapshot>, Box<dyn std::error::Error>> {
75 let cutoff = chrono::Utc::now() - chrono::Duration::days(i64::from(days));
76 let cutoff_str = cutoff.to_rfc3339();
77
78 let mut stmt = self.db.prepare(
79 "SELECT timestamp, equity_usdc, realized_pnl, unrealized_pnl
80 FROM equity_snapshots
81 WHERE timestamp >= ?1
82 ORDER BY timestamp ASC",
83 )?;
84
85 let rows = stmt.query_map(params![cutoff_str], |row| {
86 Ok(EquitySnapshot {
87 timestamp: row.get(0)?,
88 equity_usdc: row.get(1)?,
89 realized_pnl: row.get(2)?,
90 unrealized_pnl: row.get(3)?,
91 })
92 })?;
93
94 let mut snapshots = Vec::new();
95 for row in rows {
96 snapshots.push(row?);
97 }
98 Ok(snapshots)
99 }
100
101 pub fn get_summary(&self, days: u32) -> Result<EquitySummary, Box<dyn std::error::Error>> {
103 let snapshots = self.get_snapshots(days)?;
104
105 if snapshots.is_empty() {
106 return Ok(EquitySummary {
107 current_equity_usdc: 0.0,
108 peak_equity_usdc: 0.0,
109 current_drawdown_usdc: 0.0,
110 current_drawdown_pct: 0.0,
111 max_drawdown_usdc: 0.0,
112 max_drawdown_pct: 0.0,
113 snapshots,
114 });
115 }
116
117 let current_equity = snapshots.last().map(|s| s.equity_usdc).unwrap_or(0.0);
118
119 let mut peak = f64::NEG_INFINITY;
121 let mut max_dd_usdc: f64 = 0.0;
122 let mut max_dd_pct: f64 = 0.0;
123
124 for snap in &snapshots {
125 if snap.equity_usdc > peak {
126 peak = snap.equity_usdc;
127 }
128 let dd_usdc = peak - snap.equity_usdc;
129 let dd_pct = if peak > 0.0 {
130 dd_usdc / peak * 100.0
131 } else {
132 0.0
133 };
134 if dd_usdc > max_dd_usdc {
135 max_dd_usdc = dd_usdc;
136 max_dd_pct = dd_pct;
137 }
138 }
139
140 let current_dd_usdc = peak - current_equity;
141 let current_dd_pct = if peak > 0.0 {
142 current_dd_usdc / peak * 100.0
143 } else {
144 0.0
145 };
146
147 Ok(EquitySummary {
148 current_equity_usdc: current_equity,
149 peak_equity_usdc: peak,
150 current_drawdown_usdc: current_dd_usdc,
151 current_drawdown_pct: current_dd_pct,
152 max_drawdown_usdc: max_dd_usdc,
153 max_drawdown_pct: max_dd_pct,
154 snapshots,
155 })
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    fn make_tracker() -> EquityTracker {
        EquityTracker::new(":memory:").unwrap()
    }

    /// Records a zero-PnL snapshot holding `equity_usdc` at time `ts`.
    fn record(tracker: &EquityTracker, ts: chrono::DateTime<chrono::Utc>, equity_usdc: f64) {
        tracker
            .record_snapshot(&EquitySnapshot {
                timestamp: ts.to_rfc3339(),
                equity_usdc,
                realized_pnl: 0.0,
                unrealized_pnl: 0.0,
            })
            .unwrap();
    }

    #[test]
    fn empty_summary_returns_zeros() {
        let tracker = make_tracker();
        let summary = tracker.get_summary(7).unwrap();
        assert_eq!(summary.current_equity_usdc, 0.0);
        assert_eq!(summary.peak_equity_usdc, 0.0);
        assert_eq!(summary.current_drawdown_usdc, 0.0);
        assert_eq!(summary.current_drawdown_pct, 0.0);
        assert_eq!(summary.max_drawdown_usdc, 0.0);
        assert_eq!(summary.max_drawdown_pct, 0.0);
        assert!(summary.snapshots.is_empty());
    }

    #[test]
    fn records_and_retrieves_snapshots() {
        let tracker = make_tracker();
        let now = chrono::Utc::now();

        for i in 0..3 {
            let ts = (now + chrono::Duration::seconds(i)).to_rfc3339();
            tracker
                .record_snapshot(&EquitySnapshot {
                    timestamp: ts,
                    equity_usdc: 10000.0 + (i as f64) * 100.0,
                    realized_pnl: (i as f64) * 50.0,
                    unrealized_pnl: (i as f64) * 50.0,
                })
                .unwrap();
        }

        let snaps = tracker.get_snapshots(1).unwrap();
        assert_eq!(snaps.len(), 3);
        assert_eq!(snaps[2].realized_pnl, 100.0);
    }

    #[test]
    fn excludes_snapshots_older_than_window() {
        let tracker = make_tracker();
        let now = chrono::Utc::now();
        record(&tracker, now - chrono::Duration::days(3), 9000.0);
        record(&tracker, now, 10000.0);

        let recent = tracker.get_snapshots(1).unwrap();
        assert_eq!(recent.len(), 1);
        assert_eq!(recent[0].equity_usdc, 10000.0);

        assert_eq!(tracker.get_snapshots(7).unwrap().len(), 2);
    }

    #[test]
    fn returns_snapshots_in_chronological_order() {
        let tracker = make_tracker();
        let now = chrono::Utc::now();
        // Inserted out of order; the query must sort by timestamp.
        record(&tracker, now + chrono::Duration::seconds(2), 3.0);
        record(&tracker, now, 1.0);
        record(&tracker, now + chrono::Duration::seconds(1), 2.0);

        let equities: Vec<f64> = tracker
            .get_snapshots(1)
            .unwrap()
            .iter()
            .map(|s| s.equity_usdc)
            .collect();
        assert_eq!(equities, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn drawdown_calculation() {
        let tracker = make_tracker();
        let now = chrono::Utc::now();

        let equities = [10000.0, 10500.0, 10000.0, 10200.0];
        for (i, &eq) in equities.iter().enumerate() {
            record(&tracker, now + chrono::Duration::seconds(i as i64), eq);
        }

        let summary = tracker.get_summary(1).unwrap();
        assert_eq!(summary.peak_equity_usdc, 10500.0);
        assert_eq!(summary.current_equity_usdc, 10200.0);
        assert!((summary.max_drawdown_usdc - 500.0).abs() < 0.01);
        assert!((summary.max_drawdown_pct - (500.0 / 10500.0 * 100.0)).abs() < 0.01);
        assert!((summary.current_drawdown_usdc - 300.0).abs() < 0.01);
        assert!((summary.current_drawdown_pct - (300.0 / 10500.0 * 100.0)).abs() < 0.01);
    }
}