1use latch_core::{MeterConfig, MeterRejectReason, MeterVerdict, SessionUsage};
2use std::collections::HashMap;
3
/// In-memory store of usage records, one [`SessionUsage`] per session id.
///
/// The derived `Default` yields a meter with no recorded sessions.
#[derive(Default)]
pub struct UsageMeter {
    /// Usage recorded so far, keyed by the session id string.
    sessions: HashMap<String, SessionUsage>,
}
8
9impl UsageMeter {
10 pub fn new() -> Self {
11 Self::default()
12 }
13
14 pub fn get_session_usage(&self, session_id: &str) -> Option<&SessionUsage> {
15 self.sessions.get(session_id)
16 }
17
18 pub fn preview_request(
19 &self,
20 session_id: &str,
21 cfg: &MeterConfig,
22 predicted_input_tokens: u64,
23 predicted_output_tokens: u64,
24 ) -> MeterVerdict {
25 let current = self
26 .sessions
27 .get(session_id)
28 .cloned()
29 .unwrap_or_default();
30 let next_input = current.input_tokens.saturating_add(predicted_input_tokens);
31 let next_output = current.output_tokens.saturating_add(predicted_output_tokens);
32 let next_requests = current.requests.saturating_add(1);
33
34 if let Some(limit) = cfg.session_token_limit {
35 if next_input.saturating_add(next_output) > limit {
36 return MeterVerdict::Reject(MeterRejectReason::SessionTokenLimitExceeded);
37 }
38 }
39
40 if let Some(limit) = cfg.session_request_limit {
41 if next_requests > limit {
42 return MeterVerdict::Reject(MeterRejectReason::SessionRequestLimitExceeded);
43 }
44 }
45
46 MeterVerdict::Allow
47 }
48
49 pub fn record_request(
50 &mut self,
51 session_id: &str,
52 cfg: &MeterConfig,
53 input_tokens: u64,
54 output_tokens: u64,
55 ) -> SessionUsage {
56 let entry = self.sessions.entry(session_id.to_string()).or_default();
57 entry.input_tokens = entry.input_tokens.saturating_add(input_tokens);
58 entry.output_tokens = entry.output_tokens.saturating_add(output_tokens);
59 entry.requests = entry.requests.saturating_add(1);
60 entry.estimated_cost = estimate_cost(cfg, entry.input_tokens, entry.output_tokens);
61 entry.clone()
62 }
63}
64
65pub fn estimate_cost(cfg: &MeterConfig, input_tokens: u64, output_tokens: u64) -> f64 {
66 let input_cost = (input_tokens as f64 / 1000.0) * cfg.price_per_1k_input_tokens;
67 let output_cost = (output_tokens as f64 / 1000.0) * cfg.price_per_1k_output_tokens;
68 input_cost + output_cost
69}
70
#[cfg(test)]
mod tests {
    use super::{estimate_cost, UsageMeter};
    use latch_core::{MeterConfig, MeterRejectReason, MeterVerdict};

    /// Tolerance used when comparing computed costs with hand-worked values.
    const COST_TOLERANCE: f64 = 1e-9;

    /// Shared fixture: a 1000-token and 2-request session budget, priced at
    /// 0.01 per 1k input tokens and 0.03 per 1k output tokens.
    fn limited_config() -> MeterConfig {
        MeterConfig {
            session_token_limit: Some(1000),
            session_request_limit: Some(2),
            price_per_1k_input_tokens: 0.01,
            price_per_1k_output_tokens: 0.03,
            currency: "USD".to_string(),
        }
    }

    #[test]
    fn tracks_usage_and_cost() {
        let mut meter = UsageMeter::new();

        let recorded = meter.record_request("s1", &limited_config(), 200, 100);

        assert_eq!(recorded.input_tokens, 200);
        assert_eq!(recorded.output_tokens, 100);
        assert_eq!(recorded.requests, 1);
        // 200 / 1000 * 0.01 + 100 / 1000 * 0.03 = 0.005
        let drift = (recorded.estimated_cost - 0.005).abs();
        assert!(drift < COST_TOLERANCE);
    }

    #[test]
    fn rejects_when_token_limit_would_be_exceeded() {
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &limited_config(), 800, 100);

        // 900 tokens already used; 130 more would overshoot the 1000 budget.
        let outcome = meter.preview_request("s1", &limited_config(), 80, 50);

        let expected = MeterVerdict::Reject(MeterRejectReason::SessionTokenLimitExceeded);
        assert_eq!(outcome, expected);
    }

    #[test]
    fn rejects_when_request_limit_would_be_exceeded() {
        let mut meter = UsageMeter::new();
        // Use up both requests the fixture allows.
        for _ in 0..2 {
            meter.record_request("s1", &limited_config(), 10, 10);
        }

        let outcome = meter.preview_request("s1", &limited_config(), 1, 1);

        let expected = MeterVerdict::Reject(MeterRejectReason::SessionRequestLimitExceeded);
        assert_eq!(outcome, expected);
    }

    #[test]
    fn estimate_cost_uses_separate_input_and_output_prices() {
        let config = limited_config();

        // 1000 / 1000 * 0.01 + 500 / 1000 * 0.03 = 0.025
        let total = estimate_cost(&config, 1000, 500);

        assert!((total - 0.025).abs() < COST_TOLERANCE);
    }
}