1use serde::{Deserialize, Serialize};
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::time::Instant;
13
/// A kind of service a provider can offer.
///
/// A provider lists the capabilities it supports in
/// [`ProviderMeta::capabilities`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Capability {
    /// LLM completion capability.
    LlmCompletion,
    /// Embedding capability.
    Embedding,
    /// Reranking capability.
    Reranking,
    /// Vector search capability.
    VectorSearch,
    /// Web search capability.
    WebSearch,
    /// Graph search capability.
    GraphSearch,
}
30
/// Region a provider is associated with (see [`ProviderMeta::region`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Region {
    /// Displayed as `US`.
    US,
    /// Displayed as `EU`.
    EU,
    /// Displayed as `CN`.
    CN,
    /// Displayed as `Local`.
    /// NOTE(review): presumably a locally hosted provider — confirm.
    Local,
}
43
44impl std::fmt::Display for Region {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 match self {
47 Self::US => write!(f, "US"),
48 Self::EU => write!(f, "EU"),
49 Self::CN => write!(f, "CN"),
50 Self::Local => write!(f, "Local"),
51 }
52 }
53}
54
/// Static description of a provider: its identity, supported capabilities,
/// vendor, and region.
///
/// Every string and slice field is `&'static`, and [`ProviderMeta::new`] is a
/// `const fn`, so instances can be constructed in `const` contexts.
#[derive(Debug, Clone)]
pub struct ProviderMeta {
    /// Provider name.
    pub name: &'static str,
    /// Provider version string.
    pub version: &'static str,
    /// Capabilities the provider supports.
    pub capabilities: &'static [Capability],
    /// Vendor name.
    pub vendor: &'static str,
    /// Region the provider is associated with.
    pub region: Region,
}
71
72impl ProviderMeta {
73 #[must_use]
75 pub const fn new(
76 name: &'static str,
77 version: &'static str,
78 capabilities: &'static [Capability],
79 vendor: &'static str,
80 region: Region,
81 ) -> Self {
82 Self {
83 name,
84 version,
85 capabilities,
86 vendor,
87 region,
88 }
89 }
90}
91
/// Context for a single provider call: tracing identifiers, the calling user,
/// and optional limits on time, cost, and tokens.
///
/// [`Default`] yields a freshly generated `trace-` id, a 30 000 ms timeout,
/// and no intent, user, cost, or token limits. Refine it with the `with_*`
/// builder methods, or start from [`ProviderCallContext::with_trace_id`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCallContext {
    /// Optional root intent identifier; set with `with_root_intent`.
    pub root_intent_id: Option<String>,
    /// Trace identifier; `Default` generates one prefixed with `trace-`.
    pub trace_id: String,
    /// Optional user identifier; set with `with_user`.
    pub user_id: Option<String>,
    /// Timeout in milliseconds (30 000 by default).
    /// NOTE(review): enforcement is not visible in this module — confirm where it is applied.
    pub timeout_ms: u64,
    /// Optional cost ceiling.
    /// NOTE(review): the cost unit is not defined here — confirm with provider implementations.
    pub max_cost: Option<f64>,
    /// Optional token ceiling.
    pub max_tokens: Option<u32>,
}
110
111impl Default for ProviderCallContext {
112 fn default() -> Self {
113 Self {
114 root_intent_id: None,
115 trace_id: generate_trace_id(),
116 user_id: None,
117 timeout_ms: 30_000, max_cost: None,
119 max_tokens: None,
120 }
121 }
122}
123
124impl ProviderCallContext {
125 pub fn with_trace_id(trace_id: impl Into<String>) -> Self {
127 Self {
128 trace_id: trace_id.into(),
129 ..Default::default()
130 }
131 }
132
133 #[must_use]
135 pub fn with_root_intent(mut self, root_intent_id: impl Into<String>) -> Self {
136 self.root_intent_id = Some(root_intent_id.into());
137 self
138 }
139
140 #[must_use]
142 pub fn with_user(mut self, user_id: impl Into<String>) -> Self {
143 self.user_id = Some(user_id.into());
144 self
145 }
146
147 #[must_use]
149 pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
150 self.timeout_ms = timeout_ms;
151 self
152 }
153
154 #[must_use]
156 pub fn with_max_cost(mut self, max_cost: f64) -> Self {
157 self.max_cost = Some(max_cost);
158 self
159 }
160
161 #[must_use]
163 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
164 self.max_tokens = Some(max_tokens);
165 self
166 }
167}
168
169#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
171pub struct TokenUsage {
172 pub input_tokens: u32,
174 pub output_tokens: u32,
176}
177
178impl TokenUsage {
179 #[must_use]
181 pub const fn new(input_tokens: u32, output_tokens: u32) -> Self {
182 Self {
183 input_tokens,
184 output_tokens,
185 }
186 }
187
188 #[must_use]
190 pub const fn total(&self) -> u32 {
191 self.input_tokens + self.output_tokens
192 }
193}
194
/// The result of a provider call together with provenance and usage data.
///
/// `T` is the payload type stored in `content`. Build one with
/// [`ProviderObservation::new`] plus the `with_*` builders;
/// [`ProviderObservation::provenance`] renders `vendor:model:observation_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderObservation<T> {
    /// Identifier generated by `new` (prefixed with `obs-`).
    pub observation_id: String,
    /// Request hash; empty until set with `with_request_hash`.
    /// NOTE(review): presumably produced by `canonical_hash` — confirm with callers.
    pub request_hash: String,
    /// Vendor name.
    pub vendor: String,
    /// Model name.
    pub model: String,
    /// Call latency in milliseconds.
    pub latency_ms: u64,
    /// Optional cost estimate; set with `with_cost`.
    /// NOTE(review): the cost unit is not defined in this module.
    pub cost_estimate: Option<f64>,
    /// Optional input/output token usage; set with `with_tokens`.
    pub tokens: Option<TokenUsage>,
    /// The call's payload.
    pub content: T,
    /// Optional raw response text, size-capped by `with_raw_response`.
    /// Skipped during serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw_response: Option<String>,
}
227
228impl<T> ProviderObservation<T> {
229 pub fn new(
231 vendor: impl Into<String>,
232 model: impl Into<String>,
233 content: T,
234 latency_ms: u64,
235 ) -> Self {
236 let observation_id = generate_observation_id();
237 Self {
238 observation_id,
239 request_hash: String::new(),
240 vendor: vendor.into(),
241 model: model.into(),
242 latency_ms,
243 cost_estimate: None,
244 tokens: None,
245 content,
246 raw_response: None,
247 }
248 }
249
250 #[must_use]
252 pub fn with_request_hash(mut self, hash: impl Into<String>) -> Self {
253 self.request_hash = hash.into();
254 self
255 }
256
257 #[must_use]
259 pub fn with_cost(mut self, cost: f64) -> Self {
260 self.cost_estimate = Some(cost);
261 self
262 }
263
264 #[must_use]
266 pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
267 self.tokens = Some(TokenUsage::new(input, output));
268 self
269 }
270
271 #[must_use]
273 pub fn with_raw_response(mut self, raw: impl Into<String>) -> Self {
274 let raw = raw.into();
275 const MAX_RAW_SIZE: usize = 10_000;
276 if raw.len() > MAX_RAW_SIZE {
277 self.raw_response = Some(format!("{}...[truncated]", &raw[..MAX_RAW_SIZE]));
278 } else {
279 self.raw_response = Some(raw);
280 }
281 self
282 }
283
284 pub fn provenance(&self) -> String {
289 format!("{}:{}:{}", self.vendor, self.model, self.observation_id)
290 }
291}
292
/// Measures how long an operation takes, using [`Instant`].
///
/// The timer is a plain `Copy` value; copying it keeps the same start instant.
#[derive(Debug, Clone, Copy)]
pub struct CallTimer {
    start: Instant,
}

impl CallTimer {
    /// Starts a timer at the current instant.
    #[must_use]
    pub fn start() -> Self {
        Self {
            start: Instant::now(),
        }
    }

    /// Returns whole milliseconds elapsed since [`start`](Self::start).
    ///
    /// Saturates at `u64::MAX` instead of silently truncating the `u128`
    /// returned by `Duration::as_millis`.
    #[must_use]
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }
}
315
/// Returns a stable hash of `data` formatted as `hash:` followed by 16
/// lowercase hex digits.
///
/// Uses 64-bit FNV-1a over the UTF-8 bytes of `data`. `DefaultHasher` is not
/// suitable here: its algorithm is documented as unspecified and may change
/// between Rust releases. FNV-1a is fixed, so equal inputs hash equally across
/// toolchains, platforms, and processes. This is not a cryptographic hash;
/// do not rely on it for collision resistance against adversarial input.
#[must_use]
pub fn canonical_hash(data: &str) -> String {
    // Standard 64-bit FNV parameters. NOTE: changing the algorithm changes
    // every hash value, so treat these as part of the output contract.
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = data.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("hash:{digest:016x}")
}
326
/// Builds an id of the form `{prefix}-{millis:x}-{seq:x}`.
///
/// `millis` is the current Unix time in milliseconds (0 if the system clock
/// reads earlier than the epoch) and `seq` is the next value taken from
/// `counter`. Ids drawn from the same counter in one process never repeat
/// (until the 64-bit counter wraps). Ids from different processes created in
/// the same millisecond may collide.
fn timestamped_id(prefix: &str, counter: &std::sync::atomic::AtomicU64) -> String {
    use std::sync::atomic::Ordering;

    // Relaxed is enough: only the atomicity of the increment matters here.
    let seq = counter.fetch_add(1, Ordering::Relaxed);
    let millis = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| d.as_millis());

    format!("{prefix}-{millis:x}-{seq:x}")
}

/// Generates a new observation id (`obs-<millis hex>-<seq hex>`).
fn generate_observation_id() -> String {
    use std::sync::atomic::AtomicU64;
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    timestamped_id("obs", &COUNTER)
}

/// Generates a new trace id (`trace-<millis hex>-<seq hex>`).
fn generate_trace_id() -> String {
    use std::sync::atomic::AtomicU64;
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    timestamped_id("trace", &COUNTER)
}
354
#[cfg(test)]
mod tests {
    //! Unit tests for provider metadata, call contexts, observations, and helpers.

    use super::*;
    use std::time::Duration;

    #[test]
    fn test_provider_meta() {
        static CAPABILITIES: &[Capability] = &[Capability::LlmCompletion];
        let meta = ProviderMeta::new("test", "1.0", CAPABILITIES, "test-vendor", Region::US);
        assert_eq!(meta.region, Region::US);
        assert_eq!(meta.name, "test");
    }

    #[test]
    fn test_call_context_default() {
        let context = ProviderCallContext::default();
        assert!(context.trace_id.starts_with("trace-"));
        assert_eq!(context.timeout_ms, 30_000);
    }

    #[test]
    fn test_call_context_builder() {
        let context = ProviderCallContext::default()
            .with_root_intent("intent-123")
            .with_user("user-456")
            .with_timeout_ms(5000)
            .with_max_cost(1.0)
            .with_max_tokens(1000);

        assert_eq!(context.root_intent_id.as_deref(), Some("intent-123"));
        assert_eq!(context.user_id.as_deref(), Some("user-456"));
        assert_eq!(context.timeout_ms, 5000);
        assert_eq!(context.max_cost, Some(1.0));
        assert_eq!(context.max_tokens, Some(1000));
    }

    #[test]
    fn test_token_usage() {
        assert_eq!(TokenUsage::new(100, 50).total(), 150);
    }

    #[test]
    fn test_observation_provenance() {
        let provenance =
            ProviderObservation::new("anthropic", "claude-3", "content", 100).provenance();
        assert!(provenance.starts_with("anthropic:claude-3:obs-"));
    }

    #[test]
    fn test_observation_builder() {
        let observation = ProviderObservation::new("openai", "gpt-4", "response", 500)
            .with_request_hash("hash:abc123")
            .with_cost(0.05)
            .with_tokens(100, 50);

        assert_eq!(observation.request_hash, "hash:abc123");
        assert_eq!(observation.cost_estimate, Some(0.05));
        let tokens = observation.tokens.expect("with_tokens sets token usage");
        assert_eq!(tokens.total(), 150);
    }

    #[test]
    fn test_raw_response_truncation() {
        let observation = ProviderObservation::new("test", "model", "content", 100)
            .with_raw_response("x".repeat(20_000));

        let raw = observation.raw_response.expect("raw response is stored");
        assert!(raw.ends_with("...[truncated]"));
        assert!(raw.len() < 15_000);
    }

    #[test]
    fn test_canonical_hash_deterministic() {
        let first = canonical_hash("test input");
        assert_eq!(first, canonical_hash("test input"));
        assert_ne!(first, canonical_hash("different input"));
    }

    #[test]
    fn test_call_timer() {
        let timer = CallTimer::start();
        std::thread::sleep(Duration::from_millis(10));
        assert!(timer.elapsed_ms() >= 10);
    }

    #[test]
    fn test_observation_ids_unique() {
        let first = ProviderObservation::new("test", "model", "a", 1);
        let second = ProviderObservation::new("test", "model", "b", 2);
        assert_ne!(first.observation_id, second.observation_id);
    }
}