// ai_lib_core/client/builder.rs

use crate::client::core::AiClient;
use crate::feedback::FeedbackSink;
use crate::protocol::ProtocolLoader;
use crate::Result;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tokio::sync::Semaphore;

/// Fluent builder for [`AiClient`]: configure options with the chained
/// setters, then finish with `build(model).await`.
pub struct AiClientBuilder {
    /// Base path handed to the [`ProtocolLoader`] for manifest lookup, if set.
    protocol_path: Option<String>,
    /// Whether the protocol loader should hot-reload manifests.
    hot_reload: bool,
    /// Fallback model identifiers passed through to the built client.
    fallbacks: Vec<String>,
    /// Force strict streaming validation of the loaded manifest.
    strict_streaming: bool,
    /// Sink that receives feedback events; defaults to a no-op sink.
    feedback: Arc<dyn FeedbackSink>,
    /// Optional cap on concurrent in-flight requests (clamped to >= 1 when set).
    max_inflight: Option<usize>,
    /// Optional base-URL override applied when constructing the transport.
    base_url_override: Option<String>,
}
37
38impl AiClientBuilder {
39 pub fn new() -> Self {
40 Self {
41 protocol_path: None,
42 hot_reload: false,
43 fallbacks: Vec::new(),
44 strict_streaming: false,
45 feedback: crate::feedback::noop_sink(),
46 max_inflight: None,
47 base_url_override: None,
48 }
49 }
50
51 pub fn protocol_path(mut self, path: String) -> Self {
53 self.protocol_path = Some(path);
54 self
55 }
56
57 pub fn hot_reload(mut self, enable: bool) -> Self {
59 self.hot_reload = enable;
60 self
61 }
62
63 pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
65 self.fallbacks = fallbacks;
66 self
67 }
68
69 pub fn strict_streaming(mut self, enable: bool) -> Self {
73 self.strict_streaming = enable;
74 self
75 }
76
77 pub fn feedback_sink(mut self, sink: Arc<dyn FeedbackSink>) -> Self {
79 self.feedback = sink;
80 self
81 }
82
83 pub fn max_inflight(mut self, n: usize) -> Self {
86 self.max_inflight = Some(n.max(1));
87 self
88 }
89
90 pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
95 self.base_url_override = Some(base_url.into());
96 self
97 }
98
99 pub async fn build(self, model: &str) -> Result<AiClient> {
101 let mut loader = ProtocolLoader::new();
102
103 if let Some(path) = self.protocol_path {
104 loader = loader.with_base_path(path);
105 }
106
107 if self.hot_reload {
108 loader = loader.with_hot_reload(true);
109 }
110
111 let parts: Vec<&str> = model.split('/').collect();
113 let model_id = if parts.len() >= 2 {
114 parts[1..].join("/")
115 } else {
116 model.to_string()
117 };
118
119 let manifest = loader.load_model(model).await?;
120 let strict_streaming = self.strict_streaming
121 || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
122 crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
123
124 let base_url_override = self
126 .base_url_override
127 .or_else(|| std::env::var("MOCK_HTTP_URL").ok());
128
129 let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
130 &manifest,
131 &model_id,
132 base_url_override.as_deref(),
133 )?);
134 let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
135
136 let max_inflight = self.max_inflight.or_else(|| {
137 std::env::var("AI_LIB_MAX_INFLIGHT")
138 .ok()?
139 .parse::<usize>()
140 .ok()
141 });
142 let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
143
144 let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
146 .ok()
147 .and_then(|s| s.parse::<u64>().ok())
148 .filter(|ms| *ms > 0)
149 .map(std::time::Duration::from_millis);
150
151 Ok(AiClient {
152 manifest,
153 transport,
154 pipeline,
155 loader: Arc::new(loader),
156 fallbacks: self.fallbacks,
157 model_id,
158 strict_streaming,
159 feedback: self.feedback,
160 inflight,
161 max_inflight,
162 attempt_timeout,
163 total_requests: AtomicU64::new(0),
164 successful_requests: AtomicU64::new(0),
165 total_tokens: AtomicU64::new(0),
166 })
167 }
168}
169
170impl Default for AiClientBuilder {
171 fn default() -> Self {
172 Self::new()
173 }
174}