use crate::error::EvaluationError;
use regex::Regex;
use scouter_types::genai::{AggregationType, SpanFilter, SpanStatus, TraceAssertion};
use scouter_types::sql::TraceSpan;
use serde_json::{json, Value};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::debug;
12
13#[derive(Debug, Clone)]
14pub struct TraceContextBuilder {
15 pub(crate) spans: Arc<Vec<TraceSpan>>,
17}
18
19impl TraceContextBuilder {
20 pub fn new(spans: Arc<Vec<TraceSpan>>) -> Self {
21 Self { spans }
22 }
23
24 pub fn build_context(&self, assertion: &TraceAssertion) -> Result<Value, EvaluationError> {
26 match assertion {
27 TraceAssertion::SpanSequence { span_names } => {
28 Ok(json!(self.match_span_sequence(span_names)?))
29 }
30 TraceAssertion::SpanSet { span_names } => Ok(json!(self.match_span_set(span_names)?)),
31 TraceAssertion::SpanCount { filter } => Ok(json!(self.count_spans(filter)?)),
32 TraceAssertion::SpanExists { filter } => Ok(json!(self.span_exists(filter)?)),
33 TraceAssertion::SpanAttribute {
34 filter,
35 attribute_key,
36 } => self.extract_span_attribute(filter, attribute_key),
37 TraceAssertion::SpanDuration { filter } => self.extract_span_duration(filter),
38 TraceAssertion::SpanAggregation {
39 filter,
40 attribute_key,
41 aggregation,
42 } => self.aggregate_span_attribute(filter, attribute_key, aggregation),
43 TraceAssertion::TraceDuration {} => Ok(json!(self.calculate_trace_duration())),
44 TraceAssertion::TraceSpanCount {} => Ok(json!(self.spans.len())),
45 TraceAssertion::TraceErrorCount {} => Ok(json!(self.count_error_spans())),
46 TraceAssertion::TraceServiceCount {} => Ok(json!(self.count_unique_services())),
47 TraceAssertion::TraceMaxDepth {} => Ok(json!(self.calculate_max_depth())),
48 TraceAssertion::TraceAttribute { attribute_key } => {
49 self.extract_trace_attribute(attribute_key)
50 }
51 }
52 }
53
54 fn filter_spans(&self, filter: &SpanFilter) -> Result<Vec<&TraceSpan>, EvaluationError> {
56 let mut filtered = Vec::new();
57
58 for span in self.spans.iter() {
59 if self.matches_filter(span, filter)? {
60 filtered.push(span);
61 }
62 }
63
64 debug!(
65 "Filtered spans count: {} with filter {:?}",
66 filtered.len(),
67 filter
68 );
69
70 Ok(filtered)
71 }
72
73 fn matches_filter(
74 &self,
75 span: &TraceSpan,
76 filter: &SpanFilter,
77 ) -> Result<bool, EvaluationError> {
78 match filter {
79 SpanFilter::ByName { name } => Ok(span.span_name == *name),
80
81 SpanFilter::ByNamePattern { pattern } => {
82 let regex = Regex::new(pattern)?;
83 Ok(regex.is_match(&span.span_name))
84 }
85
86 SpanFilter::WithAttribute { key } => {
87 Ok(span.attributes.iter().any(|attr| attr.key == *key))
88 }
89
90 SpanFilter::WithAttributeValue { key, value } => {
91 Ok(span.attributes.iter().any(|attr| {
92 attr.key == *key && self.attribute_value_matches(&attr.value, &value.0)
93 }))
94 }
95
96 SpanFilter::WithStatus { status } => {
97 Ok(self.map_status_code(span.status_code) == *status)
98 }
99
100 SpanFilter::WithDuration { min_ms, max_ms } => {
101 let duration_f64 = span.duration_ms as f64;
102 let min_ok = min_ms.is_none_or(|min| duration_f64 >= min);
103 let max_ok = max_ms.is_none_or(|max| duration_f64 <= max);
104 Ok(min_ok && max_ok)
105 }
106
107 SpanFilter::And { filters } => {
108 for f in filters {
109 if !self.matches_filter(span, f)? {
110 return Ok(false);
111 }
112 }
113 Ok(true)
114 }
115
116 SpanFilter::Or { filters } => {
117 for f in filters {
118 if self.matches_filter(span, f)? {
119 return Ok(true);
120 }
121 }
122 Ok(false)
123 }
124
125 SpanFilter::Sequence { .. } => Err(EvaluationError::InvalidFilter(
126 "Sequence filter not applicable to individual spans".to_string(),
127 )),
128 }
129 }
130
131 fn match_span_sequence(&self, span_names: &[String]) -> Result<bool, EvaluationError> {
133 let executed_names = self.get_ordered_span_names()?;
134 Ok(executed_names == span_names)
135 }
136
137 fn match_span_set(&self, span_names: &[String]) -> Result<bool, EvaluationError> {
139 let unique_names: HashSet<_> = self.spans.iter().map(|s| s.span_name.clone()).collect();
140 for name in span_names {
141 if !unique_names.contains(name) {
142 return Ok(false);
143 }
144 }
145 Ok(true)
146 }
147
148 fn count_spans(&self, filter: &SpanFilter) -> Result<usize, EvaluationError> {
149 match filter {
150 SpanFilter::Sequence { names } => self.count_sequence_occurrences(names),
151 _ => Ok(self.filter_spans(filter)?.len()),
152 }
153 }
154
155 fn count_sequence_occurrences(
157 &self,
158 target_sequence: &[String],
159 ) -> Result<usize, EvaluationError> {
160 if target_sequence.is_empty() {
161 return Ok(0);
162 }
163
164 let all_span_names = self.get_ordered_span_names()?;
165
166 if all_span_names.len() < target_sequence.len() {
167 return Ok(0);
168 }
169
170 Ok(all_span_names
171 .windows(target_sequence.len())
172 .filter(|window| *window == target_sequence)
173 .count())
174 }
175
176 fn get_ordered_span_names(&self) -> Result<Vec<String>, EvaluationError> {
177 let mut ordered_spans: Vec<_> = self.spans.iter().collect();
178 ordered_spans.sort_by_key(|s| s.span_order);
179
180 Ok(ordered_spans
181 .into_iter()
182 .map(|s| s.span_name.clone())
183 .collect())
184 }
185
186 fn span_exists(&self, filter: &SpanFilter) -> Result<bool, EvaluationError> {
187 Ok(!self.filter_spans(filter)?.is_empty())
188 }
189
190 fn extract_span_attribute(
191 &self,
192 filter: &SpanFilter,
193 attribute_key: &str,
194 ) -> Result<Value, EvaluationError> {
195 let filtered_spans = self.filter_spans(filter)?;
196
197 if filtered_spans.is_empty() {
198 return Ok(Value::Null);
199 }
200
201 let values: Vec<Value> = filtered_spans
202 .iter()
203 .filter_map(|span| {
204 span.attributes
205 .iter()
206 .find(|attr| attr.key == attribute_key)
207 .map(|attr| attr.value.clone())
208 })
209 .collect();
210
211 if values.len() == 1 {
212 Ok(values[0].clone())
213 } else {
214 Ok(Value::Array(values))
215 }
216 }
217
218 fn extract_span_duration(&self, filter: &SpanFilter) -> Result<Value, EvaluationError> {
219 let filtered_spans = self.filter_spans(filter)?;
220
221 let durations: Vec<i64> = filtered_spans.iter().map(|span| span.duration_ms).collect();
222
223 if durations.len() == 1 {
224 Ok(json!(durations[0]))
225 } else {
226 Ok(json!(durations))
227 }
228 }
229
230 fn aggregate_span_attribute(
231 &self,
232 filter: &SpanFilter,
233 attribute_key: &str,
234 aggregation: &AggregationType,
235 ) -> Result<Value, EvaluationError> {
236 let filtered_spans = self.filter_spans(filter)?;
237
238 match aggregation {
239 AggregationType::Count => {
240 let count = filtered_spans
241 .iter()
242 .filter(|span| span.attributes.iter().any(|attr| attr.key == attribute_key))
243 .count();
244 Ok(json!(count))
245 }
246 _ => {
247 let values: Vec<f64> = filtered_spans
248 .iter()
249 .filter_map(|span| {
250 span.attributes
251 .iter()
252 .find(|attr| attr.key == attribute_key)
253 .and_then(|attr| attr.value.as_f64())
254 })
255 .collect();
256
257 if values.is_empty() {
258 return Ok(Value::Null);
259 }
260
261 let result = match aggregation {
262 AggregationType::Count => unreachable!(),
263 AggregationType::Sum => values.iter().sum(),
264 AggregationType::Average => values.iter().sum::<f64>() / values.len() as f64,
265 AggregationType::Min => values.iter().copied().fold(f64::INFINITY, f64::min),
266 AggregationType::Max => {
267 values.iter().copied().fold(f64::NEG_INFINITY, f64::max)
268 }
269 AggregationType::First => values[0],
270 AggregationType::Last => values[values.len() - 1],
271 };
272
273 Ok(json!(result))
274 }
275 }
276 }
277
278 fn calculate_trace_duration(&self) -> i64 {
280 self.spans.iter().map(|s| s.duration_ms).max().unwrap_or(0)
281 }
282
283 fn count_error_spans(&self) -> usize {
284 self.spans
285 .iter()
286 .filter(|s| s.status_code == 2) .count()
288 }
289
290 fn count_unique_services(&self) -> usize {
291 self.spans
292 .iter()
293 .map(|s| &s.service_name)
294 .collect::<HashSet<_>>()
295 .len()
296 }
297
298 fn calculate_max_depth(&self) -> i32 {
299 self.spans.iter().map(|s| s.depth).max().unwrap_or(0)
300 }
301
302 fn extract_trace_attribute(&self, attribute_key: &str) -> Result<Value, EvaluationError> {
303 let root_span = self
304 .spans
305 .iter()
306 .find(|s| s.depth == 0)
307 .ok_or_else(|| EvaluationError::NoRootSpan)?;
308
309 root_span
310 .attributes
311 .iter()
312 .find(|attr| attr.key == attribute_key)
313 .map(|attr| attr.value.clone())
314 .ok_or_else(|| EvaluationError::AttributeNotFound(attribute_key.to_string()))
315 }
316
317 fn map_status_code(&self, code: i32) -> SpanStatus {
319 match code {
320 0 => SpanStatus::Unset,
321 1 => SpanStatus::Ok,
322 2 => SpanStatus::Error,
323 _ => SpanStatus::Unset,
324 }
325 }
326
327 fn attribute_value_matches(&self, attr_value: &Value, expected: &Value) -> bool {
328 attr_value == expected
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    use scouter_mocks::{
        create_multi_service_trace, create_nested_trace, create_sequence_pattern_trace,
        create_simple_trace, create_trace_with_attributes, create_trace_with_errors,
    };
    use scouter_types::genai::PyValueWrapper;

    /// Wraps a mock trace's spans in a fresh context builder.
    fn builder_from(spans: Vec<TraceSpan>) -> TraceContextBuilder {
        TraceContextBuilder::new(Arc::new(spans))
    }

    /// Builds the evaluation context for `assertion` and fails the test on error.
    fn evaluate(builder: &TraceContextBuilder, assertion: TraceAssertion) -> Value {
        builder
            .build_context(&assertion)
            .expect("building the trace context should succeed")
    }

    /// Shorthand for an exact span-name filter.
    fn by_name(name: &str) -> SpanFilter {
        SpanFilter::ByName {
            name: name.to_string(),
        }
    }

    /// Converts string literals into the owned names that assertions expect.
    fn owned_names(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_simple_trace_structure() {
        let spans = create_simple_trace();
        assert_eq!(spans.len(), 3);

        let root = &spans[0];
        assert_eq!(root.span_name, "root");
        assert_eq!(root.depth, 0);

        assert_eq!(
            spans[1].parent_span_id,
            Some("7370616e5f300000".to_string())
        );
    }

    #[test]
    fn test_nested_trace_depth() {
        let builder = builder_from(create_nested_trace());
        assert_eq!(builder.calculate_max_depth(), 2);
    }

    #[test]
    fn test_error_counting() {
        let builder = builder_from(create_trace_with_errors());
        assert_eq!(builder.count_error_spans(), 1);
    }

    #[test]
    fn test_attribute_filtering() {
        let builder = builder_from(create_trace_with_attributes());
        let has_model = SpanFilter::WithAttribute {
            key: "model".to_string(),
        };
        assert!(builder.span_exists(&has_model).unwrap());
    }

    #[test]
    fn test_sequence_pattern_detection() {
        let builder = builder_from(create_sequence_pattern_trace());
        let tool_then_agent = SpanFilter::Sequence {
            names: owned_names(&["call_tool", "run_agent"]),
        };
        assert_eq!(builder.count_spans(&tool_then_agent).unwrap(), 2);
    }

    #[test]
    fn test_multi_service_trace() {
        let builder = builder_from(create_multi_service_trace());
        assert_eq!(builder.count_unique_services(), 3);
    }

    #[test]
    fn test_aggregation_with_numeric_attributes() {
        let builder = builder_from(create_trace_with_attributes());
        let key = "tokens.input";
        let has_input_tokens = SpanFilter::WithAttribute {
            key: key.to_string(),
        };

        let total = builder
            .aggregate_span_attribute(&has_input_tokens, key, &AggregationType::Sum)
            .unwrap();
        assert_eq!(total, json!(150.0));
    }

    #[test]
    fn test_trace_assertion_span_sequence_evaluation() {
        let builder = builder_from(create_simple_trace());
        let assertion = TraceAssertion::SpanSequence {
            span_names: owned_names(&["root", "child_1", "child_2"]),
        };
        assert_eq!(evaluate(&builder, assertion), json!(true));
    }

    #[test]
    fn test_trace_assertion_span_set_evaluation() {
        let builder = builder_from(create_simple_trace());
        let assertion = TraceAssertion::SpanSet {
            span_names: owned_names(&["root", "child_1", "child_2"]),
        };
        assert_eq!(evaluate(&builder, assertion), json!(true));
    }

    #[test]
    fn test_trace_assertion_span_count() {
        // Evaluates a SpanCount assertion for `filter` against `builder`.
        let count_for = |builder: &TraceContextBuilder, filter: SpanFilter| {
            evaluate(builder, TraceAssertion::SpanCount { filter })
        };

        let simple = builder_from(create_simple_trace());
        assert_eq!(count_for(&simple, by_name("child_1")), json!(1));
        let child_pattern = SpanFilter::ByNamePattern {
            pattern: "^child_.*".to_string(),
        };
        assert_eq!(count_for(&simple, child_pattern), json!(2));

        let attributed = builder_from(create_trace_with_attributes());
        let expectations = [
            (
                SpanFilter::WithAttribute {
                    key: "model".to_string(),
                },
                1,
            ),
            (
                SpanFilter::WithAttributeValue {
                    key: "http.method".to_string(),
                    value: PyValueWrapper(json!("POST")),
                },
                1,
            ),
            (
                SpanFilter::WithStatus {
                    status: SpanStatus::Ok,
                },
                2,
            ),
            (
                SpanFilter::WithDuration {
                    min_ms: Some(80.0),
                    max_ms: Some(120.0),
                },
                1,
            ),
            (
                SpanFilter::And {
                    filters: vec![
                        SpanFilter::WithAttribute {
                            key: "http.method".to_string(),
                        },
                        SpanFilter::WithStatus {
                            status: SpanStatus::Ok,
                        },
                    ],
                },
                1,
            ),
            (
                SpanFilter::Or {
                    filters: vec![
                        SpanFilter::WithAttributeValue {
                            key: "http.method".to_string(),
                            value: PyValueWrapper(json!("GET")),
                        },
                        SpanFilter::WithAttributeValue {
                            key: "model".to_string(),
                            value: PyValueWrapper(json!("gpt-4")),
                        },
                    ],
                },
                1,
            ),
        ];

        for (filter, expected) in expectations {
            assert_eq!(count_for(&attributed, filter), json!(expected));
        }
    }

    #[test]
    fn test_span_exists() {
        let builder = builder_from(create_simple_trace());
        let assertion = TraceAssertion::SpanExists {
            filter: by_name("child_1"),
        };
        assert_eq!(evaluate(&builder, assertion), json!(true));
    }

    #[test]
    fn test_span_attribute() {
        let builder = builder_from(create_trace_with_attributes());
        let cases = [
            ("model", json!("gpt-4")),
            ("response", json!({"success": true, "data": {"id": 12345}})),
        ];

        for (attribute_key, expected) in cases {
            let assertion = TraceAssertion::SpanAttribute {
                filter: by_name("api_call"),
                attribute_key: attribute_key.to_string(),
            };
            assert_eq!(evaluate(&builder, assertion), expected);
        }
    }

    #[test]
    fn test_span_attribute_aggregation() {
        let builder = builder_from(create_trace_with_attributes());
        let assertion = TraceAssertion::SpanAggregation {
            filter: by_name("api_call"),
            attribute_key: "tokens.output".to_string(),
            aggregation: AggregationType::Sum,
        };
        assert_eq!(evaluate(&builder, assertion), json!(300.0));
    }

    #[test]
    fn test_sequence_pattern_counting() {
        let builder = builder_from(create_sequence_pattern_trace());
        let sequence = SpanFilter::Sequence {
            names: owned_names(&["call_tool", "run_agent"]),
        };

        // Both the two-step sequence and the plain name occur twice in this fixture.
        for filter in [sequence, by_name("call_tool")] {
            let assertion = TraceAssertion::SpanCount { filter };
            assert_eq!(evaluate(&builder, assertion), json!(2));
        }
    }
}