1use async_trait::async_trait;
45use std::fmt;
46use std::sync::atomic::{AtomicU64, Ordering};
47use std::sync::Arc;
48use tracing::{debug, info, trace};
49
50use crate::error::Result;
51use crate::model_config::ModelCost;
52use crate::traits::{ChatMessage, CompletionOptions, LLMResponse, ToolDefinition};
53
/// A provider-agnostic description of a single LLM call: the conversation,
/// optional tool definitions, optional completion options, and the target
/// provider/model pair.
#[derive(Debug, Clone)]
pub struct LLMRequest {
    /// Conversation history sent to the model.
    pub messages: Vec<ChatMessage>,
    /// Tools the model may call, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Completion options (sampling parameters etc.), if any.
    pub options: Option<CompletionOptions>,
    /// Provider identifier (free-form string, e.g. "openai").
    pub provider: String,
    /// Model identifier (free-form string, e.g. "gpt-4").
    pub model: String,
}
76
77impl LLMRequest {
78 pub fn new(
80 messages: Vec<ChatMessage>,
81 provider: impl Into<String>,
82 model: impl Into<String>,
83 ) -> Self {
84 Self {
85 messages,
86 tools: None,
87 options: None,
88 provider: provider.into(),
89 model: model.into(),
90 }
91 }
92
93 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
95 self.tools = Some(tools);
96 self
97 }
98
99 pub fn with_options(mut self, options: CompletionOptions) -> Self {
101 self.options = Some(options);
102 self
103 }
104
105 pub fn message_count(&self) -> usize {
107 self.messages.len()
108 }
109
110 pub fn tool_count(&self) -> usize {
112 self.tools.as_ref().map(|t| t.len()).unwrap_or(0)
113 }
114}
115
/// Hook interface invoked around every LLM call.
///
/// Implementors can observe — and veto, by returning `Err` — a request before
/// it is dispatched, and inspect the response afterwards. Both hooks default
/// to no-ops so implementors only override what they need.
#[async_trait]
pub trait LLMMiddleware: Send + Sync {
    /// Short, stable identifier for this middleware (used for diagnostics).
    fn name(&self) -> &str;

    /// Called before the request is dispatched. Returning `Err` aborts the
    /// call. The default implementation does nothing.
    async fn before(&self, request: &LLMRequest) -> Result<()> {
        let _ = request;
        Ok(())
    }

    /// Called after a response is received, with the wall-clock duration of
    /// the call in milliseconds. The default implementation does nothing.
    async fn after(
        &self,
        request: &LLMRequest,
        response: &LLMResponse,
        duration_ms: u64,
    ) -> Result<()> {
        let _ = (request, response, duration_ms);
        Ok(())
    }
}
159
/// Ordered collection of middlewares applied around each LLM call.
///
/// `before` hooks run in registration order; `after` hooks run in reverse
/// order, so the first-registered middleware "wraps" the later ones.
#[derive(Default)]
pub struct LLMMiddlewareStack {
    // Registered middlewares, in registration order.
    middlewares: Vec<Arc<dyn LLMMiddleware>>,
}
169
170impl LLMMiddlewareStack {
171 pub fn new() -> Self {
173 Self {
174 middlewares: Vec::new(),
175 }
176 }
177
178 pub fn add(&mut self, middleware: Arc<dyn LLMMiddleware>) {
180 self.middlewares.push(middleware);
181 }
182
183 pub fn len(&self) -> usize {
185 self.middlewares.len()
186 }
187
188 pub fn is_empty(&self) -> bool {
190 self.middlewares.is_empty()
191 }
192
193 pub async fn before(&self, request: &LLMRequest) -> Result<()> {
195 for middleware in &self.middlewares {
196 middleware.before(request).await?;
197 }
198 Ok(())
199 }
200
201 pub async fn after(
203 &self,
204 request: &LLMRequest,
205 response: &LLMResponse,
206 duration_ms: u64,
207 ) -> Result<()> {
208 for middleware in self.middlewares.iter().rev() {
209 middleware.after(request, response, duration_ms).await?;
210 }
211 Ok(())
212 }
213}
214
/// Middleware that logs requests and responses via the `tracing` crate.
pub struct LoggingLLMMiddleware {
    // Verbosity at which request/response events are emitted.
    log_level: LogLevel,
}
224
/// Verbosity for [`LoggingLLMMiddleware`].
#[derive(Debug, Clone, Copy, Default)]
pub enum LogLevel {
    /// One-line request/response summaries (counts, timings). The default.
    #[default]
    Info,
    /// Summaries plus truncated content previews.
    Debug,
    /// Full request and response payloads.
    Trace,
}
236
237impl LoggingLLMMiddleware {
238 pub fn new() -> Self {
240 Self {
241 log_level: LogLevel::Info,
242 }
243 }
244
245 pub fn with_level(level: LogLevel) -> Self {
247 Self { log_level: level }
248 }
249}
250
impl Default for LoggingLLMMiddleware {
    /// Equivalent to [`LoggingLLMMiddleware::new`]: logs at `Info` level.
    fn default() -> Self {
        Self::new()
    }
}
256
257#[async_trait]
258impl LLMMiddleware for LoggingLLMMiddleware {
259 fn name(&self) -> &str {
260 "logging"
261 }
262
263 async fn before(&self, request: &LLMRequest) -> Result<()> {
264 match self.log_level {
265 LogLevel::Info => {
266 info!(
267 provider = %request.provider,
268 model = %request.model,
269 messages = request.message_count(),
270 tools = request.tool_count(),
271 "[LLM] Request"
272 );
273 }
274 LogLevel::Debug => {
275 let last_msg = request.messages.last().map(|m| {
276 let preview = if m.content.chars().count() > 100 {
277 let truncated: String = m.content.chars().take(97).collect();
278 format!("{}...", truncated)
279 } else {
280 m.content.clone()
281 };
282 format!("[{:?}] {}", m.role, preview)
283 });
284 debug!(
285 provider = %request.provider,
286 model = %request.model,
287 messages = request.message_count(),
288 tools = request.tool_count(),
289 last_message = ?last_msg,
290 "[LLM] Request"
291 );
292 }
293 LogLevel::Trace => {
294 trace!(
295 provider = %request.provider,
296 model = %request.model,
297 messages = ?request.messages,
298 "[LLM] Full request"
299 );
300 }
301 }
302 Ok(())
303 }
304
305 async fn after(
306 &self,
307 request: &LLMRequest,
308 response: &LLMResponse,
309 duration_ms: u64,
310 ) -> Result<()> {
311 match self.log_level {
312 LogLevel::Info => {
313 info!(
314 model = %request.model,
315 tokens = response.total_tokens,
316 duration_ms = duration_ms,
317 finish_reason = ?response.finish_reason,
318 "[LLM] Response"
319 );
320 }
321 LogLevel::Debug => {
322 let preview = if response.content.chars().count() > 200 {
323 let truncated: String = response.content.chars().take(197).collect();
324 format!("{}...", truncated)
325 } else {
326 response.content.clone()
327 };
328 debug!(
329 model = %request.model,
330 tokens = response.total_tokens,
331 duration_ms = duration_ms,
332 tool_calls = response.tool_calls.len(),
333 content_preview = %preview,
334 "[LLM] Response"
335 );
336 }
337 LogLevel::Trace => {
338 trace!(
339 model = %request.model,
340 response = ?response,
341 "[LLM] Full response"
342 );
343 }
344 }
345 Ok(())
346 }
347}
348
/// Middleware that aggregates usage counters across all LLM calls.
///
/// Every counter is an `AtomicU64`, so one instance can be shared (e.g.
/// behind an `Arc`) across concurrent requests; counters only ever increase.
pub struct MetricsLLMMiddleware {
    /// Number of requests started (incremented in `before`).
    pub total_requests: AtomicU64,
    /// Sum of total tokens (prompt + completion) across responses.
    pub total_tokens: AtomicU64,
    /// Sum of prompt-side tokens across responses.
    pub prompt_tokens: AtomicU64,
    /// Sum of completion-side tokens across responses.
    pub completion_tokens: AtomicU64,
    /// Total wall-clock time spent in LLM calls, in milliseconds.
    pub total_time_ms: AtomicU64,
    /// Number of responses that contained at least one tool call.
    pub tool_call_requests: AtomicU64,
    /// Sum of prompt tokens reported as served from the provider's cache.
    pub cache_hit_tokens: AtomicU64,
    /// Number of responses reporting a nonzero cache-hit token count.
    pub requests_with_cache: AtomicU64,
}
375
376impl MetricsLLMMiddleware {
377 pub fn new() -> Self {
379 Self {
380 total_requests: AtomicU64::new(0),
381 total_tokens: AtomicU64::new(0),
382 prompt_tokens: AtomicU64::new(0),
383 completion_tokens: AtomicU64::new(0),
384 total_time_ms: AtomicU64::new(0),
385 tool_call_requests: AtomicU64::new(0),
386 cache_hit_tokens: AtomicU64::new(0),
387 requests_with_cache: AtomicU64::new(0),
388 }
389 }
390
391 pub fn get_total_requests(&self) -> u64 {
393 self.total_requests.load(Ordering::Relaxed)
394 }
395
396 pub fn get_total_tokens(&self) -> u64 {
398 self.total_tokens.load(Ordering::Relaxed)
399 }
400
401 pub fn get_average_latency_ms(&self) -> f64 {
403 let requests = self.total_requests.load(Ordering::Relaxed);
404 if requests == 0 {
405 0.0
406 } else {
407 self.total_time_ms.load(Ordering::Relaxed) as f64 / requests as f64
408 }
409 }
410
411 pub fn get_cache_hit_tokens(&self) -> u64 {
413 self.cache_hit_tokens.load(Ordering::Relaxed)
414 }
415
416 pub fn get_cache_hit_rate(&self) -> f64 {
422 let prompt = self.prompt_tokens.load(Ordering::Relaxed);
423 if prompt == 0 {
424 0.0
425 } else {
426 (self.cache_hit_tokens.load(Ordering::Relaxed) as f64 / prompt as f64) * 100.0
427 }
428 }
429
430 pub fn get_summary(&self) -> MetricsSummary {
432 MetricsSummary {
433 total_requests: self.total_requests.load(Ordering::Relaxed),
434 total_tokens: self.total_tokens.load(Ordering::Relaxed),
435 prompt_tokens: self.prompt_tokens.load(Ordering::Relaxed),
436 completion_tokens: self.completion_tokens.load(Ordering::Relaxed),
437 total_time_ms: self.total_time_ms.load(Ordering::Relaxed),
438 tool_call_requests: self.tool_call_requests.load(Ordering::Relaxed),
439 cache_hit_tokens: self.cache_hit_tokens.load(Ordering::Relaxed),
440 requests_with_cache: self.requests_with_cache.load(Ordering::Relaxed),
441 }
442 }
443}
444
impl Default for MetricsLLMMiddleware {
    /// Equivalent to [`MetricsLLMMiddleware::new`]: all counters start at 0.
    fn default() -> Self {
        Self::new()
    }
}
450
/// Immutable snapshot of [`MetricsLLMMiddleware`] counters, with derived
/// statistics (latency, throughput, cache rates, cost estimates) as methods.
#[derive(Debug, Clone, Default)]
pub struct MetricsSummary {
    /// Number of requests started.
    pub total_requests: u64,
    /// Total tokens (prompt + completion).
    pub total_tokens: u64,
    /// Prompt-side tokens.
    pub prompt_tokens: u64,
    /// Completion-side tokens.
    pub completion_tokens: u64,
    /// Total wall-clock call time, in milliseconds.
    pub total_time_ms: u64,
    /// Responses that contained at least one tool call.
    pub tool_call_requests: u64,
    /// Prompt tokens served from the provider cache.
    pub cache_hit_tokens: u64,
    /// Responses that reported a nonzero cache-hit token count.
    pub requests_with_cache: u64,
}
473
/// Fluent builder for [`MetricsSummary`], mainly useful for constructing
/// synthetic summaries (e.g. in tests and reports).
///
/// `total_tokens` is kept in sync automatically as prompt/completion token
/// counts are set.
#[derive(Debug, Clone, Default)]
pub struct MetricsSummaryBuilder {
    // Summary being assembled; returned by `build`.
    inner: MetricsSummary,
}
494
495impl MetricsSummaryBuilder {
496 pub fn new() -> Self {
498 Self::default()
499 }
500
501 pub fn requests(mut self, n: u64) -> Self {
503 self.inner.total_requests = n;
504 self
505 }
506
507 pub fn prompt_tokens(mut self, n: u64) -> Self {
509 self.inner.prompt_tokens = n;
510 self.inner.total_tokens = self.inner.prompt_tokens + self.inner.completion_tokens;
511 self
512 }
513
514 pub fn completion_tokens(mut self, n: u64) -> Self {
516 self.inner.completion_tokens = n;
517 self.inner.total_tokens = self.inner.prompt_tokens + self.inner.completion_tokens;
518 self
519 }
520
521 pub fn cache_hit_tokens(mut self, n: u64) -> Self {
523 self.inner.cache_hit_tokens = n;
524 self.inner.requests_with_cache = if n > 0 { 1 } else { 0 };
525 self
526 }
527
528 pub fn time_ms(mut self, ms: u64) -> Self {
530 self.inner.total_time_ms = ms;
531 self
532 }
533
534 pub fn tool_calls(mut self, n: u64) -> Self {
536 self.inner.tool_call_requests = n;
537 self
538 }
539
540 pub fn requests_with_cache(mut self, n: u64) -> Self {
542 self.inner.requests_with_cache = n;
543 self
544 }
545
546 pub fn build(self) -> MetricsSummary {
548 self.inner
549 }
550}
551
552impl MetricsSummary {
553 pub fn average_latency_ms(&self) -> f64 {
555 if self.total_requests == 0 {
556 0.0
557 } else {
558 self.total_time_ms as f64 / self.total_requests as f64
559 }
560 }
561
562 pub fn average_tokens_per_request(&self) -> f64 {
564 if self.total_requests == 0 {
565 0.0
566 } else {
567 self.total_tokens as f64 / self.total_requests as f64
568 }
569 }
570
571 pub fn tokens_per_second(&self) -> f64 {
581 if self.total_time_ms == 0 {
582 return 0.0;
583 }
584 (self.total_tokens as f64) / (self.total_time_ms as f64 / 1000.0)
585 }
586
587 pub fn output_tokens_per_second(&self) -> f64 {
597 if self.total_time_ms == 0 {
598 return 0.0;
599 }
600 (self.completion_tokens as f64) / (self.total_time_ms as f64 / 1000.0)
601 }
602
603 pub fn token_efficiency(&self) -> f64 {
620 if self.prompt_tokens == 0 {
621 return 0.0;
622 }
623 (self.completion_tokens as f64 / self.prompt_tokens as f64) * 100.0
624 }
625
626 pub fn cache_hit_rate(&self) -> f64 {
638 if self.prompt_tokens == 0 {
639 0.0
640 } else {
641 (self.cache_hit_tokens as f64 / self.prompt_tokens as f64) * 100.0
642 }
643 }
644
645 pub fn cache_utilization(&self) -> f64 {
650 if self.total_requests == 0 {
651 0.0
652 } else {
653 (self.requests_with_cache as f64 / self.total_requests as f64) * 100.0
654 }
655 }
656
657 pub fn estimated_savings(&self, cost_per_1k_prompt: f64) -> f64 {
676 let savings_rate = 0.9;
678 (self.cache_hit_tokens as f64 / 1000.0) * cost_per_1k_prompt * savings_rate
679 }
680
681 pub fn estimated_cost(&self, cost_per_1k_prompt: f64, cost_per_1k_completion: f64) -> f64 {
692 let uncached_prompt = self.prompt_tokens.saturating_sub(self.cache_hit_tokens);
694
695 let cached_cost = (self.cache_hit_tokens as f64 / 1000.0) * cost_per_1k_prompt * 0.1;
697 let uncached_cost = (uncached_prompt as f64 / 1000.0) * cost_per_1k_prompt;
698 let completion_cost = (self.completion_tokens as f64 / 1000.0) * cost_per_1k_completion;
699
700 cached_cost + uncached_cost + completion_cost
701 }
702
703 pub fn estimated_savings_for_model(&self, cost: &ModelCost) -> f64 {
714 self.estimated_savings(cost.input_per_1k)
715 }
716
717 pub fn estimated_cost_for_model(&self, cost: &ModelCost) -> f64 {
728 self.estimated_cost(cost.input_per_1k, cost.output_per_1k)
729 }
730}
731
732impl fmt::Display for MetricsSummary {
740 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
741 write!(
742 f,
743 "reqs={} tokens={}/{}/{} cache={:.1}% latency={:.0}ms tps={:.1}",
744 self.total_requests,
745 self.prompt_tokens,
746 self.completion_tokens,
747 self.cache_hit_tokens,
748 self.cache_hit_rate(),
749 self.average_latency_ms(),
750 self.output_tokens_per_second()
751 )
752 }
753}
754
755#[async_trait]
756impl LLMMiddleware for MetricsLLMMiddleware {
757 fn name(&self) -> &str {
758 "metrics"
759 }
760
761 async fn before(&self, _request: &LLMRequest) -> Result<()> {
762 self.total_requests.fetch_add(1, Ordering::Relaxed);
763 Ok(())
764 }
765
766 async fn after(
767 &self,
768 _request: &LLMRequest,
769 response: &LLMResponse,
770 duration_ms: u64,
771 ) -> Result<()> {
772 self.total_tokens
773 .fetch_add(response.total_tokens as u64, Ordering::Relaxed);
774 self.prompt_tokens
775 .fetch_add(response.prompt_tokens as u64, Ordering::Relaxed);
776 self.completion_tokens
777 .fetch_add(response.completion_tokens as u64, Ordering::Relaxed);
778 self.total_time_ms.fetch_add(duration_ms, Ordering::Relaxed);
779
780 if !response.tool_calls.is_empty() {
781 self.tool_call_requests.fetch_add(1, Ordering::Relaxed);
782 }
783
784 if let Some(cache_hits) = response.cache_hit_tokens {
788 self.cache_hit_tokens
789 .fetch_add(cache_hits as u64, Ordering::Relaxed);
790 if cache_hits > 0 {
791 self.requests_with_cache.fetch_add(1, Ordering::Relaxed);
792 }
793 }
794
795 Ok(())
796 }
797}
798
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal single-message request used across tests.
    fn create_test_request() -> LLMRequest {
        LLMRequest::new(
            vec![ChatMessage::user("Hello")],
            "test-provider",
            "test-model",
        )
    }

    /// Minimal response with 10 prompt + 5 completion tokens.
    fn create_test_response() -> LLMResponse {
        LLMResponse::new("Hello back!", "test-model").with_usage(10, 5)
    }

    #[tokio::test]
    async fn test_empty_middleware_stack() {
        let stack = LLMMiddlewareStack::new();
        assert!(stack.is_empty());

        let request = create_test_request();
        let response = create_test_response();

        // Both hooks must be no-ops on an empty stack.
        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 100).await.unwrap();
    }

    #[tokio::test]
    async fn test_middleware_stack_with_logging() {
        let mut stack = LLMMiddlewareStack::new();
        stack.add(Arc::new(LoggingLLMMiddleware::new()));

        assert_eq!(stack.len(), 1);

        let request = create_test_request();
        let response = create_test_response();

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 100).await.unwrap();
    }

    #[tokio::test]
    async fn test_metrics_middleware() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(metrics.clone());

        let request = create_test_request();
        let response = create_test_response();

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 150).await.unwrap();

        // One request, 10 + 5 tokens, 150 ms.
        assert_eq!(metrics.get_total_requests(), 1);
        assert_eq!(metrics.get_total_tokens(), 15);
        assert_eq!(metrics.get_average_latency_ms(), 150.0);
    }

    #[tokio::test]
    async fn test_multiple_middlewares() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(Arc::new(LoggingLLMMiddleware::new()));
        stack.add(metrics.clone());

        assert_eq!(stack.len(), 2);

        let request = create_test_request();
        let response = create_test_response();

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 200).await.unwrap();

        assert_eq!(metrics.get_total_requests(), 1);
    }

    #[tokio::test]
    async fn test_metrics_summary() {
        let metrics = MetricsLLMMiddleware::new();

        let request = create_test_request();
        let response = create_test_response();

        // Two calls: 100 ms and 200 ms, 15 tokens each.
        metrics.before(&request).await.unwrap();
        metrics.after(&request, &response, 100).await.unwrap();

        metrics.before(&request).await.unwrap();
        metrics.after(&request, &response, 200).await.unwrap();

        let summary = metrics.get_summary();
        assert_eq!(summary.total_requests, 2);
        assert_eq!(summary.total_tokens, 30);
        assert_eq!(summary.average_latency_ms(), 150.0);
        assert_eq!(summary.average_tokens_per_request(), 15.0);
    }

    #[test]
    fn test_llm_request_builder() {
        let request = LLMRequest::new(vec![ChatMessage::user("Test")], "openai", "gpt-4")
            .with_options(CompletionOptions::with_temperature(0.7));

        assert_eq!(request.provider, "openai");
        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.message_count(), 1);
        assert_eq!(request.tool_count(), 0);
        assert!(request.options.is_some());
    }

    #[tokio::test]
    async fn test_cache_metrics_tracking() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(metrics.clone());

        let request = create_test_request();

        // 80 of 100 prompt tokens came from the cache.
        let response = LLMResponse::new("Hello", "test-model")
            .with_usage(100, 20)
            .with_cache_hit_tokens(80);

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 100).await.unwrap();

        assert_eq!(metrics.get_cache_hit_tokens(), 80);
        assert_eq!(metrics.get_summary().requests_with_cache, 1);

        // 80 / 100 prompt tokens -> 80% hit rate.
        let rate = metrics.get_cache_hit_rate();
        assert!((rate - 80.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_cache_metrics_none() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(metrics.clone());

        let request = create_test_request();

        // No cache information reported at all.
        let response = LLMResponse::new("Hello", "test-model").with_usage(100, 20);

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 100).await.unwrap();

        assert_eq!(metrics.get_cache_hit_tokens(), 0);
        assert_eq!(metrics.get_summary().requests_with_cache, 0);
        assert_eq!(metrics.get_cache_hit_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_cache_metrics_zero_hits() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(metrics.clone());

        let request = create_test_request();

        // Cache info reported, but zero hits: must not count as a cached request.
        let response = LLMResponse::new("Hello", "test-model")
            .with_usage(100, 20)
            .with_cache_hit_tokens(0);

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 100).await.unwrap();

        assert_eq!(metrics.get_cache_hit_tokens(), 0);
        assert_eq!(metrics.get_summary().requests_with_cache, 0);
    }

    #[tokio::test]
    async fn test_cache_metrics_aggregation() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(metrics.clone());

        let request = create_test_request();

        // Request 1: 100 prompt tokens, 80 cached.
        let response1 = LLMResponse::new("Hello", "test-model")
            .with_usage(100, 20)
            .with_cache_hit_tokens(80);

        stack.before(&request).await.unwrap();
        stack.after(&request, &response1, 100).await.unwrap();

        // Request 2: 200 prompt tokens, 150 cached.
        let response2 = LLMResponse::new("World", "test-model")
            .with_usage(200, 50)
            .with_cache_hit_tokens(150);

        stack.before(&request).await.unwrap();
        stack.after(&request, &response2, 100).await.unwrap();

        // Request 3: 100 prompt tokens, no cache info.
        let response3 = LLMResponse::new("Test", "test-model").with_usage(100, 30);

        stack.before(&request).await.unwrap();
        stack.after(&request, &response3, 100).await.unwrap();

        let summary = metrics.get_summary();
        assert_eq!(summary.total_requests, 3);
        assert_eq!(summary.prompt_tokens, 400);
        assert_eq!(summary.cache_hit_tokens, 230);
        assert_eq!(summary.requests_with_cache, 2);

        // 230 / 400 = 57.5% hit rate.
        assert!((summary.cache_hit_rate() - 57.5).abs() < 0.01);

        // 2 / 3 requests cached = 66.67% utilization.
        assert!((summary.cache_utilization() - 66.67).abs() < 0.1);
    }

    #[test]
    fn test_cache_hit_rate_calculation() {
        let summary = MetricsSummary {
            total_requests: 10,
            total_tokens: 1500,
            prompt_tokens: 1000,
            completion_tokens: 500,
            total_time_ms: 5000,
            tool_call_requests: 5,
            cache_hit_tokens: 800,
            requests_with_cache: 8,
        };

        // 800 / 1000 = 80%.
        assert!((summary.cache_hit_rate() - 80.0).abs() < 0.01);

        // 8 / 10 = 80%.
        assert!((summary.cache_utilization() - 80.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_hit_rate_zero_prompts() {
        let summary = MetricsSummary {
            total_requests: 0,
            total_tokens: 0,
            prompt_tokens: 0,
            completion_tokens: 0,
            total_time_ms: 0,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // Division-by-zero guards must kick in.
        assert_eq!(summary.cache_hit_rate(), 0.0);
        assert_eq!(summary.cache_utilization(), 0.0);
    }

    #[test]
    fn test_estimated_savings_calculation() {
        let summary = MetricsSummary {
            total_requests: 5,
            total_tokens: 1500,
            prompt_tokens: 1000,
            completion_tokens: 500,
            total_time_ms: 5000,
            tool_call_requests: 2,
            cache_hit_tokens: 800,
            requests_with_cache: 4,
        };

        // (800 / 1000) * 0.003 * 0.9 = 0.00216.
        let savings = summary.estimated_savings(0.003);
        assert!((savings - 0.00216).abs() < 0.00001);
    }

    #[test]
    fn test_estimated_savings_zero_cache() {
        let summary = MetricsSummary {
            total_requests: 5,
            total_tokens: 1500,
            prompt_tokens: 1000,
            completion_tokens: 500,
            total_time_ms: 5000,
            tool_call_requests: 2,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // No cached tokens -> no savings.
        assert_eq!(summary.estimated_savings(0.003), 0.0);
    }

    #[test]
    fn test_estimated_cost_all_cached() {
        let summary = MetricsSummary {
            total_requests: 1,
            total_tokens: 2000,
            prompt_tokens: 1000,
            completion_tokens: 1000,
            total_time_ms: 100,
            tool_call_requests: 0,
            cache_hit_tokens: 1000,
            requests_with_cache: 1,
        };

        // cached: 1 * 0.003 * 0.1 = 0.0003
        // uncached: 0
        // completion: 1 * 0.015 = 0.015
        // total = 0.0153
        let cost = summary.estimated_cost(0.003, 0.015);
        assert!((cost - 0.0153).abs() < 0.00001);
    }

    #[test]
    fn test_estimated_cost_no_cache() {
        let summary = MetricsSummary {
            total_requests: 1,
            total_tokens: 2000,
            prompt_tokens: 1000,
            completion_tokens: 1000,
            total_time_ms: 100,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // prompt: 1 * 0.003 = 0.003; completion: 1 * 0.015 = 0.015; total 0.018.
        let cost = summary.estimated_cost(0.003, 0.015);
        assert!((cost - 0.018).abs() < 0.00001);
    }

    #[test]
    fn test_estimated_cost_partial_cache() {
        let summary = MetricsSummary {
            total_requests: 1,
            total_tokens: 2000,
            prompt_tokens: 1000,
            completion_tokens: 1000,
            total_time_ms: 100,
            tool_call_requests: 0,
            cache_hit_tokens: 500,
            requests_with_cache: 1,
        };

        // cached: 0.5 * 0.003 * 0.1 = 0.00015
        // uncached: 0.5 * 0.003 = 0.0015
        // completion: 1 * 0.015 = 0.015
        // total = 0.01665
        let cost = summary.estimated_cost(0.003, 0.015);
        assert!((cost - 0.01665).abs() < 0.00001);
    }

    #[test]
    fn test_estimated_savings_for_model() {
        use crate::model_config::ModelCost;

        let summary = MetricsSummary {
            total_requests: 5,
            total_tokens: 1500,
            prompt_tokens: 1000,
            completion_tokens: 500,
            total_time_ms: 5000,
            tool_call_requests: 2,
            cache_hit_tokens: 800,
            requests_with_cache: 4,
        };

        let model_cost = ModelCost {
            input_per_1k: 0.003,
            output_per_1k: 0.015,
            embedding_per_1k: 0.0,
            image_per_unit: 0.0,
            currency: "USD".to_string(),
        };

        // The model-based helper must agree with the raw-price overload.
        let savings_direct = summary.estimated_savings(0.003);
        let savings_model = summary.estimated_savings_for_model(&model_cost);
        assert!((savings_direct - savings_model).abs() < 0.00001);
    }

    #[test]
    fn test_estimated_cost_for_model() {
        use crate::model_config::ModelCost;

        let summary = MetricsSummary {
            total_requests: 1,
            total_tokens: 2000,
            prompt_tokens: 1000,
            completion_tokens: 1000,
            total_time_ms: 100,
            tool_call_requests: 0,
            cache_hit_tokens: 500,
            requests_with_cache: 1,
        };

        let model_cost = ModelCost {
            input_per_1k: 0.003,
            output_per_1k: 0.015,
            embedding_per_1k: 0.0,
            image_per_unit: 0.0,
            currency: "USD".to_string(),
        };

        // The model-based helper must agree with the raw-price overload.
        let cost_direct = summary.estimated_cost(0.003, 0.015);
        let cost_model = summary.estimated_cost_for_model(&model_cost);
        assert!((cost_direct - cost_model).abs() < 0.00001);
    }

    #[test]
    fn test_model_cost_gpt4_pricing() {
        use crate::model_config::ModelCost;

        let summary = MetricsSummary {
            total_requests: 10,
            total_tokens: 50000,
            prompt_tokens: 40000,
            completion_tokens: 10000,
            total_time_ms: 30000,
            tool_call_requests: 5,
            cache_hit_tokens: 35000,
            requests_with_cache: 9,
        };

        let gpt4_cost = ModelCost {
            input_per_1k: 0.0025,
            output_per_1k: 0.01,
            embedding_per_1k: 0.0,
            image_per_unit: 0.0,
            currency: "USD".to_string(),
        };

        // savings: 35 * 0.0025 * 0.9 = 0.07875
        let savings = summary.estimated_savings_for_model(&gpt4_cost);
        assert!((savings - 0.07875).abs() < 0.0001);

        // cached: 35 * 0.0025 * 0.1 = 0.00875
        // uncached: 5 * 0.0025 = 0.0125
        // completion: 10 * 0.01 = 0.1
        // total = 0.12125
        let cost = summary.estimated_cost_for_model(&gpt4_cost);
        assert!((cost - 0.12125).abs() < 0.0001);
    }

    #[test]
    fn test_metrics_summary_display() {
        let summary = MetricsSummary {
            total_requests: 10,
            total_tokens: 5000,
            prompt_tokens: 4000,
            completion_tokens: 1000,
            total_time_ms: 1500,
            tool_call_requests: 3,
            cache_hit_tokens: 3200,
            requests_with_cache: 8,
        };

        let display = format!("{}", summary);

        assert!(display.contains("reqs=10"));
        // prompt/completion/cache-hit triple.
        assert!(display.contains("tokens=4000/1000/3200"));
        // 3200 / 4000 = 80.0%.
        assert!(display.contains("cache=80.0%"));
        // 1500 ms / 10 requests = 150 ms.
        assert!(display.contains("latency=150ms"));
        assert!(display.contains("tps="));
    }

    #[test]
    fn test_metrics_summary_display_zero_values() {
        let summary = MetricsSummary {
            total_requests: 0,
            total_tokens: 0,
            prompt_tokens: 0,
            completion_tokens: 0,
            total_time_ms: 0,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // All-zero summaries must render without dividing by zero.
        let display = format!("{}", summary);

        assert!(display.contains("reqs=0"));
        assert!(display.contains("cache=0.0%"));
        assert!(display.contains("tps=0.0"));
    }

    #[test]
    fn test_tokens_per_second() {
        let summary = MetricsSummary {
            total_requests: 5,
            total_tokens: 10000,
            prompt_tokens: 8000,
            completion_tokens: 2000,
            total_time_ms: 2000,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // 10000 tokens / 2 s = 5000 tps.
        assert!((summary.tokens_per_second() - 5000.0).abs() < 0.1);
    }

    #[test]
    fn test_output_tokens_per_second() {
        let summary = MetricsSummary {
            total_requests: 5,
            total_tokens: 10000,
            prompt_tokens: 8000,
            completion_tokens: 2000,
            total_time_ms: 2000,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // 2000 completion tokens / 2 s = 1000 tps.
        assert!((summary.output_tokens_per_second() - 1000.0).abs() < 0.1);
    }

    #[test]
    fn test_throughput_zero_time() {
        let summary = MetricsSummary {
            total_requests: 1,
            total_tokens: 1000,
            prompt_tokens: 800,
            completion_tokens: 200,
            total_time_ms: 0,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // Zero elapsed time must not divide by zero.
        assert_eq!(summary.tokens_per_second(), 0.0);
        assert_eq!(summary.output_tokens_per_second(), 0.0);
    }

    #[test]
    fn test_throughput_realistic_session() {
        let summary = MetricsSummary {
            total_requests: 20,
            total_tokens: 50000,
            prompt_tokens: 40000,
            completion_tokens: 10000,
            total_time_ms: 30000,
            tool_call_requests: 15,
            cache_hit_tokens: 35000,
            requests_with_cache: 18,
        };

        // 50000 / 30 s = 1666.67 tps overall.
        assert!((summary.tokens_per_second() - 1666.67).abs() < 1.0);

        // 10000 / 30 s = 333.33 output tps.
        assert!((summary.output_tokens_per_second() - 333.33).abs() < 1.0);
    }

    #[test]
    fn test_token_efficiency() {
        let summary = MetricsSummary {
            total_requests: 5,
            total_tokens: 5000,
            prompt_tokens: 4000,
            completion_tokens: 1000,
            total_time_ms: 5000,
            tool_call_requests: 2,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // 1000 / 4000 = 25%.
        assert!((summary.token_efficiency() - 25.0).abs() < 0.01);
    }

    #[test]
    fn test_token_efficiency_low() {
        let summary = MetricsSummary {
            total_requests: 1,
            total_tokens: 10100,
            prompt_tokens: 10000,
            completion_tokens: 100,
            total_time_ms: 1000,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // 100 / 10000 = 1%.
        assert!((summary.token_efficiency() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_token_efficiency_zero_prompt() {
        let summary = MetricsSummary {
            total_requests: 0,
            total_tokens: 0,
            prompt_tokens: 0,
            completion_tokens: 0,
            total_time_ms: 0,
            tool_call_requests: 0,
            cache_hit_tokens: 0,
            requests_with_cache: 0,
        };

        // No prompt tokens -> defined as 0.0 rather than dividing by zero.
        assert_eq!(summary.token_efficiency(), 0.0);
    }

    #[test]
    fn test_builder_basic() {
        let summary = MetricsSummaryBuilder::new()
            .requests(10)
            .prompt_tokens(4000)
            .completion_tokens(1000)
            .time_ms(1500)
            .build();

        assert_eq!(summary.total_requests, 10);
        assert_eq!(summary.prompt_tokens, 4000);
        assert_eq!(summary.completion_tokens, 1000);
        // total_tokens derived automatically: 4000 + 1000.
        assert_eq!(summary.total_tokens, 5000);
        assert_eq!(summary.total_time_ms, 1500);
    }

    #[test]
    fn test_builder_with_cache() {
        let summary = MetricsSummaryBuilder::new()
            .requests(5)
            .prompt_tokens(10000)
            .completion_tokens(2000)
            .cache_hit_tokens(8000)
            .build();

        // 8000 / 10000 = 80%.
        assert!((summary.cache_hit_rate() - 80.0).abs() < 0.01);
        // Nonzero cache hits imply one cached request by default.
        assert_eq!(summary.requests_with_cache, 1);
    }

    #[test]
    fn test_builder_default() {
        let summary = MetricsSummaryBuilder::new().build();

        assert_eq!(summary.total_requests, 0);
        assert_eq!(summary.prompt_tokens, 0);
        assert_eq!(summary.completion_tokens, 0);
        assert_eq!(summary.cache_hit_tokens, 0);
    }

    #[test]
    fn test_builder_metrics_calculation() {
        let summary = MetricsSummaryBuilder::new()
            .requests(10)
            .prompt_tokens(5000)
            .completion_tokens(1000)
            .cache_hit_tokens(4000)
            .time_ms(2000)
            .build();

        // 2000 ms / 10 requests = 200 ms.
        assert!((summary.average_latency_ms() - 200.0).abs() < 0.01);

        // 1000 completion tokens / 2 s = 500 tps.
        assert!((summary.output_tokens_per_second() - 500.0).abs() < 0.01);

        // 1000 / 5000 = 20%.
        assert!((summary.token_efficiency() - 20.0).abs() < 0.01);
    }

    #[test]
    fn test_llm_request_with_tools() {
        let tools = vec![ToolDefinition::function(
            "get_weather",
            "Get weather data",
            serde_json::json!({}),
        )];
        let request = LLMRequest::new(vec![ChatMessage::user("Hi")], "p", "m").with_tools(tools);
        assert_eq!(request.tool_count(), 1);
        assert!(request.tools.is_some());
    }

    #[tokio::test]
    async fn test_metrics_tool_call_tracking() {
        let metrics = Arc::new(MetricsLLMMiddleware::new());
        let mut stack = LLMMiddlewareStack::new();
        stack.add(metrics.clone());

        let request = create_test_request();

        // Response carrying a single tool call.
        let mut response = LLMResponse::new("", "m").with_usage(10, 5);
        response.tool_calls = vec![crate::traits::ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: crate::traits::FunctionCall {
                name: "test".to_string(),
                arguments: "{}".to_string(),
            },
        }];

        stack.before(&request).await.unwrap();
        stack.after(&request, &response, 100).await.unwrap();

        let summary = metrics.get_summary();
        assert_eq!(summary.tool_call_requests, 1);
    }

    #[tokio::test]
    async fn test_logging_middleware_debug_level() {
        let logging = LoggingLLMMiddleware::with_level(LogLevel::Debug);
        assert_eq!(logging.name(), "logging");

        let request = create_test_request();
        let response = create_test_response();

        logging.before(&request).await.unwrap();
        logging.after(&request, &response, 100).await.unwrap();
    }

    #[tokio::test]
    async fn test_logging_middleware_trace_level() {
        let logging = LoggingLLMMiddleware::with_level(LogLevel::Trace);
        let request = create_test_request();
        let response = create_test_response();

        logging.before(&request).await.unwrap();
        logging.after(&request, &response, 100).await.unwrap();
    }

    #[test]
    fn test_logging_middleware_default() {
        let logging = LoggingLLMMiddleware::default();
        assert_eq!(logging.name(), "logging");
    }

    #[test]
    fn test_metrics_middleware_default() {
        let metrics = MetricsLLMMiddleware::default();
        assert_eq!(metrics.get_total_requests(), 0);
        assert_eq!(metrics.get_total_tokens(), 0);
    }

    #[test]
    fn test_middleware_stack_default() {
        let stack = LLMMiddlewareStack::default();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_builder_tool_calls() {
        let summary = MetricsSummaryBuilder::new()
            .requests(5)
            .tool_calls(3)
            .build();
        assert_eq!(summary.tool_call_requests, 3);
    }

    #[test]
    fn test_builder_requests_with_cache_override() {
        // Explicit requests_with_cache wins over the cache_hit_tokens default.
        let summary = MetricsSummaryBuilder::new()
            .requests(10)
            .requests_with_cache(7)
            .build();
        assert_eq!(summary.requests_with_cache, 7);
    }

    #[tokio::test]
    async fn test_logging_debug_long_message() {
        // 200-char message exercises the request-side truncation path.
        let logging = LoggingLLMMiddleware::with_level(LogLevel::Debug);
        let long_msg = "x".repeat(200);
        let request = LLMRequest::new(vec![ChatMessage::user(&long_msg)], "p", "m");
        logging.before(&request).await.unwrap();
    }

    #[tokio::test]
    async fn test_logging_debug_long_response() {
        // 300-char content exercises the response-side truncation path.
        let logging = LoggingLLMMiddleware::with_level(LogLevel::Debug);
        let request = create_test_request();
        let long_content = "y".repeat(300);
        let response = LLMResponse::new(&long_content, "m").with_usage(10, 5);
        logging.after(&request, &response, 100).await.unwrap();
    }

    #[test]
    fn test_metrics_cache_hit_rate_no_prompts() {
        let metrics = MetricsLLMMiddleware::new();
        assert_eq!(metrics.get_cache_hit_rate(), 0.0);
    }
}