Skip to main content

converge_provider/
contract.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Provider contract types for structured observations and call context.
5//!
6//! These types define the boundary between providers (adapters) and the
7//! Converge core engine. Providers produce observations; the engine decides.
8
9use serde::{Deserialize, Serialize};
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::time::Instant;
13
14/// Capabilities that providers can offer.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub enum Capability {
17    /// LLM text completion
18    LlmCompletion,
19    /// Text/image embedding generation
20    Embedding,
21    /// Re-ranking search results
22    Reranking,
23    /// Vector similarity search
24    VectorSearch,
25    /// Web search with citations
26    WebSearch,
27    /// Graph pattern matching
28    GraphSearch,
29}
30
31/// Data sovereignty region for a provider.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
33pub enum Region {
34    /// United States
35    US,
36    /// European Union
37    EU,
38    /// China
39    CN,
40    /// Local (on-premise, no network)
41    Local,
42}
43
44impl std::fmt::Display for Region {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        match self {
47            Self::US => write!(f, "US"),
48            Self::EU => write!(f, "EU"),
49            Self::CN => write!(f, "CN"),
50            Self::Local => write!(f, "Local"),
51        }
52    }
53}
54
55/// Metadata about a provider implementation.
56///
57/// This is static information that describes what a provider offers.
58#[derive(Debug, Clone)]
59pub struct ProviderMeta {
60    /// Provider name (e.g., "anthropic", "openai")
61    pub name: &'static str,
62    /// Provider version
63    pub version: &'static str,
64    /// Capabilities this provider offers
65    pub capabilities: &'static [Capability],
66    /// Vendor identifier
67    pub vendor: &'static str,
68    /// Region (for data sovereignty)
69    pub region: Region,
70}
71
72impl ProviderMeta {
73    /// Create new provider metadata.
74    #[must_use]
75    pub const fn new(
76        name: &'static str,
77        version: &'static str,
78        capabilities: &'static [Capability],
79        vendor: &'static str,
80        region: Region,
81    ) -> Self {
82        Self {
83            name,
84            version,
85            capabilities,
86            vendor,
87            region,
88        }
89    }
90}
91
92/// Context passed to every provider call for tracing and budgets.
93///
94/// This provides correlation IDs, timeouts, and budget constraints.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct ProviderCallContext {
97    /// Root intent ID for correlation (from converge-core)
98    pub root_intent_id: Option<String>,
99    /// Trace ID for distributed tracing
100    pub trace_id: String,
101    /// User/org identifier (for auditing)
102    pub user_id: Option<String>,
103    /// Maximum allowed latency in milliseconds (timeout)
104    pub timeout_ms: u64,
105    /// Maximum allowed cost in USD
106    pub max_cost: Option<f64>,
107    /// Maximum tokens (input + output)
108    pub max_tokens: Option<u32>,
109}
110
111impl Default for ProviderCallContext {
112    fn default() -> Self {
113        Self {
114            root_intent_id: None,
115            trace_id: generate_trace_id(),
116            user_id: None,
117            timeout_ms: 30_000, // 30 seconds
118            max_cost: None,
119            max_tokens: None,
120        }
121    }
122}
123
124impl ProviderCallContext {
125    /// Create a new call context with a specific trace ID.
126    pub fn with_trace_id(trace_id: impl Into<String>) -> Self {
127        Self {
128            trace_id: trace_id.into(),
129            ..Default::default()
130        }
131    }
132
133    /// Set the root intent ID.
134    #[must_use]
135    pub fn with_root_intent(mut self, root_intent_id: impl Into<String>) -> Self {
136        self.root_intent_id = Some(root_intent_id.into());
137        self
138    }
139
140    /// Set the user ID.
141    #[must_use]
142    pub fn with_user(mut self, user_id: impl Into<String>) -> Self {
143        self.user_id = Some(user_id.into());
144        self
145    }
146
147    /// Set the timeout in milliseconds.
148    #[must_use]
149    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
150        self.timeout_ms = timeout_ms;
151        self
152    }
153
154    /// Set the maximum cost budget.
155    #[must_use]
156    pub fn with_max_cost(mut self, max_cost: f64) -> Self {
157        self.max_cost = Some(max_cost);
158        self
159    }
160
161    /// Set the maximum token budget.
162    #[must_use]
163    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
164        self.max_tokens = Some(max_tokens);
165        self
166    }
167}
168
169/// Token usage information from a provider call.
170#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
171pub struct TokenUsage {
172    /// Number of input tokens
173    pub input_tokens: u32,
174    /// Number of output tokens
175    pub output_tokens: u32,
176}
177
178impl TokenUsage {
179    /// Create new token usage.
180    #[must_use]
181    pub const fn new(input_tokens: u32, output_tokens: u32) -> Self {
182        Self {
183            input_tokens,
184            output_tokens,
185        }
186    }
187
188    /// Total tokens used.
189    #[must_use]
190    pub const fn total(&self) -> u32 {
191        self.input_tokens + self.output_tokens
192    }
193}
194
195/// Structured result from every provider call.
196///
197/// This is the core type that providers return. It includes:
198/// - The actual content
199/// - Provenance metadata for tracing
200/// - Cost and latency information
201///
202/// # Type Parameter
203///
204/// `T` is the content type (typically `String` for LLM responses).
205#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct ProviderObservation<T> {
207    /// Stable reference ID for this observation
208    pub observation_id: String,
209    /// Canonical hash of the request
210    pub request_hash: String,
211    /// Provider that produced this observation
212    pub vendor: String,
213    /// Model used
214    pub model: String,
215    /// Call latency in milliseconds
216    pub latency_ms: u64,
217    /// Estimated cost in USD (if known)
218    pub cost_estimate: Option<f64>,
219    /// Token usage (if applicable)
220    pub tokens: Option<TokenUsage>,
221    /// The actual content
222    pub content: T,
223    /// Raw response (optional, size-bounded)
224    #[serde(skip_serializing_if = "Option::is_none")]
225    pub raw_response: Option<String>,
226}
227
228impl<T> ProviderObservation<T> {
229    /// Create a new observation.
230    pub fn new(
231        vendor: impl Into<String>,
232        model: impl Into<String>,
233        content: T,
234        latency_ms: u64,
235    ) -> Self {
236        let observation_id = generate_observation_id();
237        Self {
238            observation_id,
239            request_hash: String::new(),
240            vendor: vendor.into(),
241            model: model.into(),
242            latency_ms,
243            cost_estimate: None,
244            tokens: None,
245            content,
246            raw_response: None,
247        }
248    }
249
250    /// Set the request hash.
251    #[must_use]
252    pub fn with_request_hash(mut self, hash: impl Into<String>) -> Self {
253        self.request_hash = hash.into();
254        self
255    }
256
257    /// Set the cost estimate.
258    #[must_use]
259    pub fn with_cost(mut self, cost: f64) -> Self {
260        self.cost_estimate = Some(cost);
261        self
262    }
263
264    /// Set the token usage.
265    #[must_use]
266    pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
267        self.tokens = Some(TokenUsage::new(input, output));
268        self
269    }
270
271    /// Set the raw response (will be truncated if too long).
272    #[must_use]
273    pub fn with_raw_response(mut self, raw: impl Into<String>) -> Self {
274        let raw = raw.into();
275        const MAX_RAW_SIZE: usize = 10_000;
276        if raw.len() > MAX_RAW_SIZE {
277            self.raw_response = Some(format!("{}...[truncated]", &raw[..MAX_RAW_SIZE]));
278        } else {
279            self.raw_response = Some(raw);
280        }
281        self
282    }
283
284    /// Generate provenance string for Facts.
285    ///
286    /// This string can be attached to `ProposedFact` instances to trace
287    /// where the data came from.
288    pub fn provenance(&self) -> String {
289        format!("{}:{}:{}", self.vendor, self.model, self.observation_id)
290    }
291}
292
/// A timer for measuring provider call latency.
///
/// Captures a monotonic [`Instant`] when started, so readings are not affected
/// by wall-clock adjustments.
#[derive(Debug, Clone, Copy)]
pub struct CallTimer {
    start: Instant,
}

impl CallTimer {
    /// Start a new timer at the current instant.
    #[must_use]
    pub fn start() -> Self {
        Self {
            start: Instant::now(),
        }
    }

    /// Get the time elapsed since [`CallTimer::start`], in whole milliseconds.
    ///
    /// Saturates at `u64::MAX` rather than silently truncating the `u128`
    /// millisecond count, as an `as` cast would.
    #[must_use]
    pub fn elapsed_ms(&self) -> u64 {
        u64::try_from(self.start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }
}
315
/// Compute a canonical hash for a request.
///
/// Produces a deterministic 64-bit FNV-1a fingerprint of `data`'s UTF-8 bytes,
/// formatted as `hash:` followed by 16 lowercase hex digits. FNV-1a is fully
/// specified, so the value is stable across processes, platforms, and Rust
/// releases. `std`'s `DefaultHasher` makes no such promise. That stability
/// makes the hash suitable for caching and provenance tracking.
///
/// This is not a cryptographic hash and offers no collision resistance against
/// adversarial input.
#[must_use]
pub fn canonical_hash(data: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = data.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("hash:{digest:016x}")
}
326
/// Produce an observation identifier that is unique within this process.
///
/// Format: `obs-<unix-millis in hex>-<sequence in hex>`. The sequence comes
/// from a process-wide atomic counter. A clock set before the Unix epoch
/// yields a timestamp of `0`.
fn generate_observation_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis());

    format!("obs-{millis:x}-{seq:x}")
}
340
/// Produce a trace identifier that is unique within this process.
///
/// Format: `trace-<unix-millis in hex>-<sequence in hex>`. The sequence
/// counter is separate from the one used for observation IDs. A clock set
/// before the Unix epoch yields a timestamp of `0`.
fn generate_trace_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis());

    format!("trace-{millis:x}-{seq:x}")
}
354
#[cfg(test)]
mod tests {
    //! Unit tests for the provider contract types.

    use super::*;

    #[test]
    fn test_provider_meta() {
        static LLM_ONLY: &[Capability] = &[Capability::LlmCompletion];
        let descriptor = ProviderMeta::new("test", "1.0", LLM_ONLY, "test-vendor", Region::US);
        assert_eq!(descriptor.name, "test");
        assert_eq!(descriptor.region, Region::US);
    }

    #[test]
    fn test_call_context_default() {
        let context = ProviderCallContext::default();
        assert_eq!(context.timeout_ms, 30_000);
        assert!(context.trace_id.starts_with("trace-"));
    }

    #[test]
    fn test_call_context_builder() {
        let context = ProviderCallContext::default()
            .with_root_intent("intent-123")
            .with_user("user-456")
            .with_timeout_ms(5000)
            .with_max_cost(1.0)
            .with_max_tokens(1000);

        assert_eq!(context.root_intent_id.as_deref(), Some("intent-123"));
        assert_eq!(context.user_id.as_deref(), Some("user-456"));
        assert_eq!(context.timeout_ms, 5000);
        assert_eq!(context.max_cost, Some(1.0));
        assert_eq!(context.max_tokens, Some(1000));
    }

    #[test]
    fn test_token_usage() {
        assert_eq!(TokenUsage::new(100, 50).total(), 150);
    }

    #[test]
    fn test_observation_provenance() {
        let provenance =
            ProviderObservation::new("anthropic", "claude-3", "content", 100).provenance();
        assert!(provenance.starts_with("anthropic:claude-3:obs-"));
    }

    #[test]
    fn test_observation_builder() {
        let observation = ProviderObservation::new("openai", "gpt-4", "response", 500)
            .with_request_hash("hash:abc123")
            .with_cost(0.05)
            .with_tokens(100, 50);

        assert_eq!(observation.request_hash, "hash:abc123");
        assert_eq!(observation.cost_estimate, Some(0.05));
        assert_eq!(observation.tokens.map(|usage| usage.total()), Some(150));
    }

    #[test]
    fn test_raw_response_truncation() {
        let oversized = "x".repeat(20_000);
        let observation =
            ProviderObservation::new("test", "model", "content", 100).with_raw_response(oversized);

        let raw = observation
            .raw_response
            .expect("raw response should be recorded");
        assert!(raw.ends_with("...[truncated]"));
        assert!(raw.len() < 15_000);
    }

    #[test]
    fn test_canonical_hash_deterministic() {
        let first = canonical_hash("test input");
        assert_eq!(first, canonical_hash("test input"));
        assert_ne!(first, canonical_hash("different input"));
    }

    #[test]
    fn test_call_timer() {
        let timer = CallTimer::start();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(timer.elapsed_ms() >= 10);
    }

    #[test]
    fn test_observation_ids_unique() {
        let first = ProviderObservation::new("test", "model", "a", 1);
        let second = ProviderObservation::new("test", "model", "b", 2);
        assert_ne!(first.observation_id, second.observation_id);
    }
}
450}