1use latch_core::{MeterConfig, MeterRejectReason, MeterVerdict, SessionUsage};
2use std::collections::HashMap;
3
4#[derive(Default)]
5pub struct UsageMeter {
6 sessions: HashMap<String, SessionUsage>,
7}
8
9impl UsageMeter {
10 pub fn new() -> Self {
11 Self::default()
12 }
13
14 pub fn get_session_usage(&self, session_id: &str) -> Option<&SessionUsage> {
15 self.sessions.get(session_id)
16 }
17
18 pub fn preview_request(
19 &self,
20 session_id: &str,
21 cfg: &MeterConfig,
22 predicted_input_tokens: u64,
23 predicted_output_tokens: u64,
24 ) -> MeterVerdict {
25 let current = self.sessions.get(session_id).cloned().unwrap_or_default();
26 let next_input = current.input_tokens.saturating_add(predicted_input_tokens);
27 let next_output = current
28 .output_tokens
29 .saturating_add(predicted_output_tokens);
30 let next_requests = current.requests.saturating_add(1);
31
32 if let Some(limit) = cfg.session_token_limit {
33 if next_input.saturating_add(next_output) > limit {
34 return MeterVerdict::Reject(MeterRejectReason::SessionTokenLimitExceeded);
35 }
36 }
37
38 if let Some(limit) = cfg.session_request_limit {
39 if next_requests > limit {
40 return MeterVerdict::Reject(MeterRejectReason::SessionRequestLimitExceeded);
41 }
42 }
43
44 MeterVerdict::Allow
45 }
46
47 pub fn record_request(
48 &mut self,
49 session_id: &str,
50 cfg: &MeterConfig,
51 input_tokens: u64,
52 output_tokens: u64,
53 ) -> SessionUsage {
54 let entry = self.sessions.entry(session_id.to_string()).or_default();
55 entry.input_tokens = entry.input_tokens.saturating_add(input_tokens);
56 entry.output_tokens = entry.output_tokens.saturating_add(output_tokens);
57 entry.requests = entry.requests.saturating_add(1);
58 entry.estimated_cost = estimate_cost(cfg, entry.input_tokens, entry.output_tokens);
59 entry.clone()
60 }
61}
62
63pub fn estimate_cost(cfg: &MeterConfig, input_tokens: u64, output_tokens: u64) -> f64 {
64 let input_cost = (input_tokens as f64 / 1000.0) * cfg.price_per_1k_input_tokens;
65 let output_cost = (output_tokens as f64 / 1000.0) * cfg.price_per_1k_output_tokens;
66 input_cost + output_cost
67}
68
#[cfg(test)]
mod tests {
    use super::{estimate_cost, UsageMeter};
    use latch_core::{MeterConfig, MeterRejectReason, MeterVerdict};

    /// Baseline config: 1000-token and 2-request session limits, with
    /// distinct input and output prices.
    fn cfg() -> MeterConfig {
        MeterConfig {
            session_token_limit: Some(1000),
            session_request_limit: Some(2),
            price_per_1k_input_tokens: 0.01,
            price_per_1k_output_tokens: 0.03,
            currency: "USD".to_string(),
        }
    }

    #[test]
    fn tracks_usage_and_cost() {
        let mut meter = UsageMeter::new();
        let usage = meter.record_request("s1", &cfg(), 200, 100);
        assert_eq!(usage.input_tokens, 200);
        assert_eq!(usage.output_tokens, 100);
        assert_eq!(usage.requests, 1);
        assert!((usage.estimated_cost - 0.005).abs() < 1e-9);
    }

    #[test]
    fn accumulates_usage_across_requests() {
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &cfg(), 200, 100);
        let usage = meter.record_request("s1", &cfg(), 300, 0);
        assert_eq!(usage.input_tokens, 500);
        assert_eq!(usage.output_tokens, 100);
        assert_eq!(usage.requests, 2);
        // 0.5 * 0.01 + 0.1 * 0.03
        assert!((usage.estimated_cost - 0.008).abs() < 1e-9);

        let stored = meter
            .get_session_usage("s1")
            .expect("usage was recorded for s1");
        assert_eq!(stored.input_tokens, 500);
        assert_eq!(stored.requests, 2);
    }

    #[test]
    fn rejects_when_token_limit_would_be_exceeded() {
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &cfg(), 800, 100);
        let verdict = meter.preview_request("s1", &cfg(), 80, 50);
        assert_eq!(
            verdict,
            MeterVerdict::Reject(MeterRejectReason::SessionTokenLimitExceeded)
        );
    }

    #[test]
    fn rejects_when_request_limit_would_be_exceeded() {
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &cfg(), 10, 10);
        meter.record_request("s1", &cfg(), 10, 10);
        let verdict = meter.preview_request("s1", &cfg(), 1, 1);
        assert_eq!(
            verdict,
            MeterVerdict::Reject(MeterRejectReason::SessionRequestLimitExceeded)
        );
    }

    #[test]
    fn allows_request_that_lands_exactly_on_limits() {
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &cfg(), 800, 100);
        // 900 used + 100 predicted == 1000 tokens; 1 recorded + 1 == 2 requests.
        // Both limits are inclusive, so this must still be allowed.
        let verdict = meter.preview_request("s1", &cfg(), 50, 50);
        assert_eq!(verdict, MeterVerdict::Allow);
    }

    #[test]
    fn allows_any_request_when_limits_are_unset() {
        let unlimited = MeterConfig {
            session_token_limit: None,
            session_request_limit: None,
            ..cfg()
        };
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &unlimited, u64::MAX, u64::MAX);
        let verdict = meter.preview_request("s1", &unlimited, u64::MAX, u64::MAX);
        assert_eq!(verdict, MeterVerdict::Allow);
    }

    #[test]
    fn preview_does_not_record_usage() {
        let meter = UsageMeter::new();
        let verdict = meter.preview_request("s1", &cfg(), 10, 10);
        assert_eq!(verdict, MeterVerdict::Allow);
        assert!(meter.get_session_usage("s1").is_none());
    }

    #[test]
    fn sessions_are_tracked_independently() {
        let mut meter = UsageMeter::new();
        meter.record_request("s1", &cfg(), 900, 100);
        // "s1" sits exactly on its token limit; "s2" has no usage yet.
        assert_eq!(
            meter.preview_request("s1", &cfg(), 1, 0),
            MeterVerdict::Reject(MeterRejectReason::SessionTokenLimitExceeded)
        );
        assert_eq!(meter.preview_request("s2", &cfg(), 1, 0), MeterVerdict::Allow);
    }

    #[test]
    fn estimate_cost_uses_separate_input_and_output_prices() {
        let c = cfg();
        let cost = estimate_cost(&c, 1000, 500);
        assert!((cost - 0.025).abs() < 1e-9);
    }
}