use crate::client::core::AiClient;
use crate::feedback::FeedbackSink;
use crate::protocol::ProtocolLoader;
use crate::Result;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Builder for [`AiClient`].
///
/// Every setter consumes the builder and returns it, so calls chain; finish with `build`,
/// passing the model identifier to load.
#[must_use = "builder methods return a new builder; it does nothing unless `build` is called"]
pub struct AiClientBuilder {
    /// Base path handed to the protocol loader; `None` leaves the loader's default.
    protocol_path: Option<String>,
    /// Whether the protocol loader is created with hot reload enabled.
    hot_reload: bool,
    /// Fallback list copied verbatim into the built client.
    fallbacks: Vec<String>,
    /// Request strict streaming validation (may also be forced via the environment).
    strict_streaming: bool,
    /// Feedback sink stored on the built client.
    feedback: Arc<dyn FeedbackSink>,
    /// In-flight limit; `None` defers to the environment (or no limit at all).
    max_inflight: Option<usize>,
    /// Explicit transport base URL; takes precedence over any environment fallback.
    base_url_override: Option<String>,
    /// Explicit credential passed to the transport and kept on the client.
    credential_override: Option<String>,
}
impl AiClientBuilder {
    /// Creates a builder with every option unset.
    ///
    /// The defaults are:
    /// - no protocol path,
    /// - hot reload and strict streaming off,
    /// - no fallbacks,
    /// - a no-op feedback sink,
    /// - no in-flight limit,
    /// - no base-URL or credential override.
    pub fn new() -> Self {
        Self {
            protocol_path: None,
            hot_reload: false,
            fallbacks: Vec::new(),
            strict_streaming: false,
            feedback: crate::feedback::noop_sink(),
            max_inflight: None,
            base_url_override: None,
            credential_override: None,
        }
    }

    /// Sets the base path passed to `ProtocolLoader::with_base_path` during `build`.
    pub fn protocol_path(mut self, path: String) -> Self {
        self.protocol_path = Some(path);
        self
    }

    /// Enables or disables hot reload on the protocol loader created by `build`.
    pub fn hot_reload(mut self, enable: bool) -> Self {
        self.hot_reload = enable;
        self
    }

    /// Replaces the fallback list stored on the built client.
    // NOTE(review): presumably alternative model ids tried on failure — confirm in `AiClient`.
    pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
        self.fallbacks = fallbacks;
        self
    }

    /// Requests strict streaming validation of the manifest.
    ///
    /// Setting `false` does not override `AI_LIB_STRICT_STREAMING=1`; see [`Self::build`].
    pub fn strict_streaming(mut self, enable: bool) -> Self {
        self.strict_streaming = enable;
        self
    }

    /// Replaces the feedback sink (a no-op sink by default) stored on the built client.
    pub fn feedback_sink(mut self, sink: Arc<dyn FeedbackSink>) -> Self {
        self.feedback = sink;
        self
    }

    /// Sets the in-flight limit.
    ///
    /// `build` turns it into a `Semaphore` with that many permits. Values below 1 are
    /// clamped to 1, and setting this takes precedence over `AI_LIB_MAX_INFLIGHT`.
    pub fn max_inflight(mut self, n: usize) -> Self {
        self.max_inflight = Some(n.max(1));
        self
    }

    /// Sets the transport base URL, taking precedence over `MOCK_HTTP_URL`.
    pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
        self.base_url_override = Some(base_url.into());
        self
    }

    /// Sets the credential given to the HTTP transport and stored on the built client.
    pub fn credential(mut self, credential: impl Into<String>) -> Self {
        self.credential_override = Some(credential.into());
        self
    }

    /// Alias for [`Self::credential`].
    pub fn api_key(self, api_key: impl Into<String>) -> Self {
        self.credential(api_key)
    }

    /// Loads the manifest for `model` and assembles an [`AiClient`].
    ///
    /// `model` is passed unchanged to the protocol loader. The client's model id is the part
    /// after the first `/` (`"provider/org/name"` becomes `"org/name"`), or all of `model`
    /// when it contains no `/`.
    ///
    /// Environment variables consulted:
    /// - `AI_LIB_STRICT_STREAMING=1` forces strict streaming on; it never turns it off.
    /// - `MOCK_HTTP_URL` supplies the base URL when [`Self::base_url_override`] was not set.
    /// - `AI_LIB_MAX_INFLIGHT` supplies the in-flight limit when [`Self::max_inflight`] was
    ///   not set. Like the builder method, it is clamped to at least 1.
    /// - `AI_LIB_ATTEMPT_TIMEOUT_MS`, when a positive integer, becomes the client's
    ///   `attempt_timeout`.
    ///
    /// Numeric values have surrounding whitespace trimmed; values that still fail to parse
    /// are ignored.
    ///
    /// # Errors
    ///
    /// Propagates failures from loading the model manifest, validating it, constructing the
    /// HTTP transport, and building the pipeline.
    pub async fn build(self, model: &str) -> Result<AiClient> {
        // Destructure so that adding a builder field forces this function to handle it.
        let Self {
            protocol_path,
            hot_reload,
            fallbacks,
            strict_streaming,
            feedback,
            max_inflight,
            base_url_override,
            credential_override,
        } = self;

        let mut loader = ProtocolLoader::new();
        if let Some(path) = protocol_path {
            loader = loader.with_base_path(path);
        }
        if hot_reload {
            loader = loader.with_hot_reload(true);
        }

        let model_id = Self::model_id_from(model);
        let manifest = loader.load_model(model).await?;

        let strict_streaming = strict_streaming
            || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
        crate::client::validation::validate_manifest(&manifest, strict_streaming)?;

        // NOTE(review): MOCK_HTTP_URL is honored in every build, not only tests. The
        // credential below goes to the same transport, so confirm this redirect is intended.
        let base_url_override =
            base_url_override.or_else(|| std::env::var("MOCK_HTTP_URL").ok());
        let transport = Arc::new(
            crate::transport::HttpTransport::new_with_base_url_and_credential(
                &manifest,
                &model_id,
                base_url_override.as_deref(),
                credential_override.as_deref(),
            )?,
        );
        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);

        // Clamp once, after choosing the source, so the stored limit always equals the
        // semaphore's permit count (an env value of "0" must not be reported as 0).
        let max_inflight = max_inflight
            .or_else(|| Self::env_parse::<usize>("AI_LIB_MAX_INFLIGHT"))
            .map(|n| n.max(1));
        let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n)));

        // Zero is treated the same as "unset": no attempt timeout.
        let attempt_timeout = Self::env_parse::<u64>("AI_LIB_ATTEMPT_TIMEOUT_MS")
            .filter(|&ms| ms > 0)
            .map(std::time::Duration::from_millis);

        Ok(AiClient {
            manifest,
            transport,
            pipeline,
            loader: Arc::new(loader),
            fallbacks,
            model_id,
            strict_streaming,
            feedback,
            inflight,
            max_inflight,
            credential_override,
            attempt_timeout,
            total_requests: AtomicU64::new(0),
            successful_requests: AtomicU64::new(0),
            total_tokens: AtomicU64::new(0),
        })
    }

    /// Returns everything after the first `/` in `model`, or `model` itself if there is none.
    fn model_id_from(model: &str) -> String {
        model
            .split_once('/')
            .map_or(model, |(_, rest)| rest)
            .to_string()
    }

    /// Reads env var `key` and parses its trimmed value.
    ///
    /// Returns `None` when the variable is unset, not valid UTF-8, or unparsable.
    fn env_parse<T: std::str::FromStr>(key: &str) -> Option<T> {
        std::env::var(key).ok()?.trim().parse().ok()
    }
}
impl Default for AiClientBuilder {
fn default() -> Self {
Self::new()
}
}