// ai_lib_core/client/builder.rs

use crate::client::core::AiClient;
use crate::feedback::FeedbackSink;
use crate::protocol::ProtocolLoader;
use crate::Result;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::Semaphore;
8
9/// Builder for creating clients with custom configuration.
10///
11/// Keep this surface area small and predictable (developer-friendly).
12///
13/// ## Sharing across tasks
14///
15/// `AiClient` does not implement `Clone` (by design, for API key and ToS compliance).
16/// To share a client across multiple async tasks, wrap it in `Arc`:
17///
18/// ```ignore
19/// let client = Arc::new(
20///     AiClientBuilder::new()
21///         .build("openai/gpt-4o")
22///         .await?
23/// );
24/// // Use Arc::clone(&client) to pass to tasks
25/// tokio::spawn(use_client(Arc::clone(&client)));
26/// ```
pub struct AiClientBuilder {
    /// Custom protocol directory; the `ProtocolLoader` default is used when `None`.
    protocol_path: Option<String>,
    /// Forwarded to `ProtocolLoader::with_hot_reload` when true.
    hot_reload: bool,
    /// Fallback model identifiers, passed through to the built client.
    fallbacks: Vec<String>,
    /// Fail fast on incomplete streaming config
    /// (can also be forced via `AI_LIB_STRICT_STREAMING=1` at build time).
    strict_streaming: bool,
    /// Sink for feedback events; defaults to a no-op sink.
    feedback: Arc<dyn FeedbackSink>,
    /// Optional cap on concurrent requests/streams (simple backpressure);
    /// `AI_LIB_MAX_INFLIGHT` is consulted at build time when unset.
    max_inflight: Option<usize>,
    /// Override base URL (primarily for testing with mock servers)
    base_url_override: Option<String>,
}
37
38impl AiClientBuilder {
39    pub fn new() -> Self {
40        Self {
41            protocol_path: None,
42            hot_reload: false,
43            fallbacks: Vec::new(),
44            strict_streaming: false,
45            feedback: crate::feedback::noop_sink(),
46            max_inflight: None,
47            base_url_override: None,
48        }
49    }
50
51    /// Set custom protocol directory path.
52    pub fn protocol_path(mut self, path: String) -> Self {
53        self.protocol_path = Some(path);
54        self
55    }
56
57    /// Enable hot reload of protocol files.
58    pub fn hot_reload(mut self, enable: bool) -> Self {
59        self.hot_reload = enable;
60        self
61    }
62
63    /// Set fallback models.
64    pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
65        self.fallbacks = fallbacks;
66        self
67    }
68
69    /// Enable strict streaming validation (fail fast when streaming config is incomplete).
70    ///
71    /// This is intentionally opt-in to preserve compatibility with partial manifests.
72    pub fn strict_streaming(mut self, enable: bool) -> Self {
73        self.strict_streaming = enable;
74        self
75    }
76
77    /// Inject a feedback sink. Default is a no-op sink.
78    pub fn feedback_sink(mut self, sink: Arc<dyn FeedbackSink>) -> Self {
79        self.feedback = sink;
80        self
81    }
82
83    /// Limit maximum number of in-flight requests/streams.
84    /// This is a simple backpressure mechanism for production safety.
85    pub fn max_inflight(mut self, n: usize) -> Self {
86        self.max_inflight = Some(n.max(1));
87        self
88    }
89
90    /// Override the base URL from the protocol manifest.
91    ///
92    /// This is primarily for testing with mock servers. In production, use the
93    /// base_url defined in the protocol manifest.
94    pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
95        self.base_url_override = Some(base_url.into());
96        self
97    }
98
99    /// Build the client.
100    pub async fn build(self, model: &str) -> Result<AiClient> {
101        let mut loader = ProtocolLoader::new();
102
103        if let Some(path) = self.protocol_path {
104            loader = loader.with_base_path(path);
105        }
106
107        if self.hot_reload {
108            loader = loader.with_hot_reload(true);
109        }
110
111        // model is in form "provider/model-id" or "provider/org/model-name" (e.g. nvidia/minimaxai/minimax-m2)
112        let parts: Vec<&str> = model.split('/').collect();
113        let model_id = if parts.len() >= 2 {
114            parts[1..].join("/")
115        } else {
116            model.to_string()
117        };
118
119        let manifest = loader.load_model(model).await?;
120        let strict_streaming = self.strict_streaming
121            || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
122        crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
123
124        // Use MOCK_HTTP_URL env var when base_url_override not set (for testing with ai-protocol-mock)
125        let base_url_override = self
126            .base_url_override
127            .or_else(|| std::env::var("MOCK_HTTP_URL").ok());
128
129        let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
130            &manifest,
131            &model_id,
132            base_url_override.as_deref(),
133        )?);
134        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
135
136        let max_inflight = self.max_inflight.or_else(|| {
137            std::env::var("AI_LIB_MAX_INFLIGHT")
138                .ok()?
139                .parse::<usize>()
140                .ok()
141        });
142        let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
143
144        // Optional per-attempt timeout (policy signal). Transport has its own timeout too; this is an extra guard.
145        let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
146            .ok()
147            .and_then(|s| s.parse::<u64>().ok())
148            .filter(|ms| *ms > 0)
149            .map(std::time::Duration::from_millis);
150
151        Ok(AiClient {
152            manifest,
153            transport,
154            pipeline,
155            loader: Arc::new(loader),
156            fallbacks: self.fallbacks,
157            model_id,
158            strict_streaming,
159            feedback: self.feedback,
160            inflight,
161            max_inflight,
162            attempt_timeout,
163            total_requests: AtomicU64::new(0),
164            successful_requests: AtomicU64::new(0),
165            total_tokens: AtomicU64::new(0),
166        })
167    }
168}
169
170impl Default for AiClientBuilder {
171    fn default() -> Self {
172        Self::new()
173    }
174}