use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::time::Instant;
/// A kind of operation a provider can advertise via `ProviderMeta::capabilities`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Capability {
    LlmCompletion,
    Embedding,
    Reranking,
    VectorSearch,
    WebSearch,
    GraphSearch,
}
/// Region a provider is associated with, as recorded in `ProviderMeta::region`.
///
/// Rendered by its `Display` impl as the bare variant name.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Region {
    US,
    EU,
    CN,
    Local,
}
impl std::fmt::Display for Region {
    /// Writes the region's short label: `US`, `EU`, `CN`, or `Local`.
    ///
    /// Width/fill flags on the outer format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::US => "US",
            Self::EU => "EU",
            Self::CN => "CN",
            Self::Local => "Local",
        };
        f.write_str(label)
    }
}
/// Static description of a provider: identity, version, capabilities, vendor and region.
///
/// Every borrowed field is `'static`, so values can be built at compile time
/// with the `const fn` constructor `ProviderMeta::new`.
#[derive(Clone, Debug)]
pub struct ProviderMeta {
    /// Provider name.
    pub name: &'static str,
    /// Provider version string.
    pub version: &'static str,
    /// Operations this provider supports.
    pub capabilities: &'static [Capability],
    /// Vendor name.
    pub vendor: &'static str,
    /// Region the provider is associated with.
    pub region: Region,
}
impl ProviderMeta {
    /// Assembles provider metadata from its parts.
    ///
    /// Declared `const` so it can initialise `const` and `static` items.
    #[must_use]
    pub const fn new(
        name: &'static str,
        version: &'static str,
        capabilities: &'static [Capability],
        vendor: &'static str,
        region: Region,
    ) -> Self {
        Self { region, vendor, capabilities, version, name }
    }
}
/// Per-call context passed to a provider: tracing ids plus optional limits.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProviderCallContext {
    /// Id of the root intent this call belongs to, if any.
    pub root_intent_id: Option<String>,
    /// Trace id; generated automatically by `Default`.
    pub trace_id: String,
    /// Optional user id.
    pub user_id: Option<String>,
    /// Call timeout in milliseconds.
    pub timeout_ms: u64,
    /// Optional cost ceiling (unit not fixed here — presumably matches `cost_estimate`; confirm).
    pub max_cost: Option<f64>,
    /// Optional token ceiling.
    pub max_tokens: Option<u32>,
}
impl Default for ProviderCallContext {
    /// A context with a freshly generated trace id and a 30 000 ms timeout.
    ///
    /// No intent, user, cost or token limits are set.
    fn default() -> Self {
        let trace_id = generate_trace_id();
        Self {
            root_intent_id: None,
            trace_id,
            user_id: None,
            timeout_ms: 30_000,
            max_cost: None,
            max_tokens: None,
        }
    }
}
impl ProviderCallContext {
    /// Creates a context like `Default::default()`, but with the given trace id.
    pub fn with_trace_id(trace_id: impl Into<String>) -> Self {
        let trace_id: String = trace_id.into();
        Self { trace_id, ..Self::default() }
    }

    /// Sets `root_intent_id`.
    #[must_use]
    pub fn with_root_intent(self, root_intent_id: impl Into<String>) -> Self {
        Self { root_intent_id: Some(root_intent_id.into()), ..self }
    }

    /// Sets `user_id`.
    #[must_use]
    pub fn with_user(self, user_id: impl Into<String>) -> Self {
        Self { user_id: Some(user_id.into()), ..self }
    }

    /// Replaces the timeout, in milliseconds.
    #[must_use]
    pub fn with_timeout_ms(self, timeout_ms: u64) -> Self {
        Self { timeout_ms, ..self }
    }

    /// Sets the cost ceiling.
    #[must_use]
    pub fn with_max_cost(self, max_cost: f64) -> Self {
        Self { max_cost: Some(max_cost), ..self }
    }

    /// Sets the token ceiling.
    #[must_use]
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self { max_tokens: Some(max_tokens), ..self }
    }
}
/// Token counts attached to a provider observation.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input-side token count.
    pub input_tokens: u32,
    /// Output-side token count.
    pub output_tokens: u32,
}
impl TokenUsage {
    /// Creates a usage record from input and output token counts.
    #[must_use]
    pub const fn new(input_tokens: u32, output_tokens: u32) -> Self {
        Self {
            input_tokens,
            output_tokens,
        }
    }

    /// Returns `input_tokens + output_tokens`, saturating at `u32::MAX`.
    ///
    /// Plain `+` would panic on overflow in debug builds and silently wrap to
    /// a small number in release builds. Saturating keeps the total a
    /// sensible upper bound in both.
    #[must_use]
    pub const fn total(&self) -> u32 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}
/// Result of one provider call together with its provenance metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProviderObservation<T> {
    /// Generated id for this observation.
    pub observation_id: String,
    /// Request fingerprint; empty unless set via `with_request_hash`.
    pub request_hash: String,
    /// Vendor that produced the result.
    pub vendor: String,
    /// Model that produced the result.
    pub model: String,
    /// Call latency in milliseconds, as supplied by the caller.
    pub latency_ms: u64,
    /// Optional cost estimate.
    pub cost_estimate: Option<f64>,
    /// Optional token usage.
    pub tokens: Option<TokenUsage>,
    /// Provider payload.
    pub content: T,
    /// Raw response text, if attached; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw_response: Option<String>,
}
impl<T> ProviderObservation<T> {
    /// Creates an observation wrapping `content` returned by `model` from `vendor`.
    ///
    /// A fresh `observation_id` is generated. `request_hash` starts empty, and
    /// `cost_estimate`, `tokens` and `raw_response` start as `None`. Use the
    /// `with_*` builder methods to fill them in.
    pub fn new(
        vendor: impl Into<String>,
        model: impl Into<String>,
        content: T,
        latency_ms: u64,
    ) -> Self {
        let observation_id = generate_observation_id();
        Self {
            observation_id,
            request_hash: String::new(),
            vendor: vendor.into(),
            model: model.into(),
            latency_ms,
            cost_estimate: None,
            tokens: None,
            content,
            raw_response: None,
        }
    }

    /// Sets the request fingerprint stored in `request_hash`.
    #[must_use]
    pub fn with_request_hash(mut self, hash: impl Into<String>) -> Self {
        self.request_hash = hash.into();
        self
    }

    /// Sets `cost_estimate`.
    #[must_use]
    pub fn with_cost(mut self, cost: f64) -> Self {
        self.cost_estimate = Some(cost);
        self
    }

    /// Records token usage for the call.
    #[must_use]
    pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
        self.tokens = Some(TokenUsage::new(input, output));
        self
    }

    /// Attaches raw response text, keeping at most 10 000 bytes of it.
    ///
    /// Longer input is cut at the last UTF-8 character boundary at or below
    /// the cap and suffixed with `...[truncated]`, so the stored string is at
    /// most 10 014 bytes. Cutting on a character boundary is required because
    /// slicing a `str` in the middle of a multi-byte character panics.
    #[must_use]
    pub fn with_raw_response(mut self, raw: impl Into<String>) -> Self {
        const MAX_RAW_SIZE: usize = 10_000;
        const TRUNCATION_MARKER: &str = "...[truncated]";

        let mut raw = raw.into();
        if raw.len() > MAX_RAW_SIZE {
            // Walk back to a char boundary. UTF-8 chars are at most 4 bytes, so
            // this takes at most 3 steps. Index 0 is always a boundary, so the
            // fallback is never hit.
            let cut = (0..=MAX_RAW_SIZE)
                .rev()
                .find(|&idx| raw.is_char_boundary(idx))
                .unwrap_or(0);
            raw.truncate(cut);
            raw.push_str(TRUNCATION_MARKER);
            // Drop the oversized allocation so the cap bounds memory, not just length.
            raw.shrink_to_fit();
        }
        self.raw_response = Some(raw);
        self
    }

    /// Returns `vendor:model:observation_id`, identifying where this result came from.
    #[must_use]
    pub fn provenance(&self) -> String {
        format!("{}:{}:{}", self.vendor, self.model, self.observation_id)
    }
}
/// Measures elapsed wall time for a provider call using `std::time::Instant`.
///
/// `Instant` is monotonic, so readings never go backwards.
#[derive(Debug, Clone, Copy)]
pub struct CallTimer {
    start: Instant,
}

impl CallTimer {
    /// Starts a timer at the current instant.
    #[must_use]
    pub fn start() -> Self {
        Self {
            start: Instant::now(),
        }
    }

    /// Returns whole milliseconds elapsed since `start`, clamped to `u64::MAX`.
    #[must_use]
    pub fn elapsed_ms(&self) -> u64 {
        // `as_millis` yields u128. Clamp before narrowing so the cast can
        // never silently truncate.
        let millis = self.start.elapsed().as_millis();
        millis.min(u128::from(u64::MAX)) as u64
    }
}
/// Returns a stable fingerprint of `data`: `hash:` followed by 16 lowercase hex digits.
///
/// Uses 64-bit FNV-1a over the UTF-8 bytes of `data`. std's `DefaultHasher`
/// documents its algorithm as unspecified and liable to change between Rust
/// releases. FNV-1a is a fixed, published algorithm, so the same input hashes
/// identically across toolchains, platforms and processes. This is not a
/// cryptographic hash; do not rely on it for adversarial collision resistance.
#[must_use]
pub fn canonical_hash(data: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let digest = data
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME));
    format!("hash:{digest:016x}")
}
/// Returns a new observation id of the form `obs-<unix_millis_hex>-<seq_hex>`.
fn generate_observation_id() -> String {
    use std::sync::atomic::AtomicU64;
    static OBSERVATION_SEQ: AtomicU64 = AtomicU64::new(0);
    next_prefixed_id("obs", &OBSERVATION_SEQ)
}

/// Returns a new trace id of the form `trace-<unix_millis_hex>-<seq_hex>`.
fn generate_trace_id() -> String {
    use std::sync::atomic::AtomicU64;
    static TRACE_SEQ: AtomicU64 = AtomicU64::new(0);
    next_prefixed_id("trace", &TRACE_SEQ)
}

/// Formats `<prefix>-<unix_millis_hex>-<seq_hex>`, drawing `seq` from `counter`.
///
/// The counter keeps ids distinct within a process even when several are made
/// in the same millisecond. A clock set before the Unix epoch yields a
/// timestamp of 0.
fn next_prefixed_id(prefix: &str, counter: &std::sync::atomic::AtomicU64) -> String {
    let seq = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let unix_millis = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis());
    format!("{prefix}-{unix_millis:x}-{seq:x}")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_meta() {
        static CAPS: &[Capability] = &[Capability::LlmCompletion];
        let meta = ProviderMeta::new("test", "1.0", CAPS, "test-vendor", Region::US);
        assert_eq!(meta.name, "test");
        assert_eq!(meta.region, Region::US);
    }

    #[test]
    fn test_call_context_default() {
        let ProviderCallContext {
            timeout_ms,
            trace_id,
            ..
        } = ProviderCallContext::default();
        assert_eq!(timeout_ms, 30_000);
        assert!(trace_id.starts_with("trace-"));
    }

    #[test]
    fn test_call_context_builder() {
        let ctx = ProviderCallContext::default()
            .with_root_intent("intent-123")
            .with_user("user-456")
            .with_timeout_ms(5000)
            .with_max_cost(1.0)
            .with_max_tokens(1000);
        assert_eq!(ctx.root_intent_id.as_deref(), Some("intent-123"));
        assert_eq!(ctx.user_id.as_deref(), Some("user-456"));
        assert_eq!(ctx.timeout_ms, 5000);
        assert_eq!(ctx.max_cost, Some(1.0));
        assert_eq!(ctx.max_tokens, Some(1000));
    }

    #[test]
    fn test_token_usage() {
        assert_eq!(TokenUsage::new(100, 50).total(), 150);
    }

    #[test]
    fn test_observation_provenance() {
        let provenance = ProviderObservation::new("anthropic", "claude-3", "content", 100).provenance();
        assert!(provenance.starts_with("anthropic:claude-3:obs-"));
    }

    #[test]
    fn test_observation_builder() {
        let obs = ProviderObservation::new("openai", "gpt-4", "response", 500)
            .with_request_hash("hash:abc123")
            .with_cost(0.05)
            .with_tokens(100, 50);
        assert_eq!(obs.request_hash, "hash:abc123");
        assert_eq!(obs.cost_estimate, Some(0.05));
        let tokens = obs.tokens.expect("with_tokens should populate tokens");
        assert_eq!(tokens.total(), 150);
    }

    #[test]
    fn test_raw_response_truncation() {
        let oversized = "x".repeat(20_000);
        let stored = ProviderObservation::new("test", "model", "content", 100)
            .with_raw_response(oversized)
            .raw_response
            .expect("with_raw_response should populate raw_response");
        assert!(stored.ends_with("...[truncated]"));
        assert!(stored.len() < 15_000);
    }

    #[test]
    fn test_canonical_hash_deterministic() {
        let first = canonical_hash("test input");
        assert_eq!(first, canonical_hash("test input"));
        assert_ne!(first, canonical_hash("different input"));
    }

    #[test]
    fn test_call_timer() {
        let timer = CallTimer::start();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(timer.elapsed_ms() >= 10);
    }

    #[test]
    fn test_observation_ids_unique() {
        let first = ProviderObservation::new("test", "model", "a", 1);
        let second = ProviderObservation::new("test", "model", "b", 2);
        assert_ne!(first.observation_id, second.observation_id);
    }
}