1use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9
10use cognis_core::{CognisError, Result, Runnable, RunnableConfig, RunnableStream};
11
12use crate::chat::{ChatOptions, ChatResponse, StreamChunk};
13use crate::provider::{LLMProvider, Provider};
14use crate::tools::ToolDefinition;
15use crate::Message;
16
17#[derive(Clone)]
19pub struct Client {
20 provider: Arc<dyn LLMProvider>,
21}
22
23impl std::fmt::Debug for Client {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("Client")
26 .field("provider", &self.provider.name())
27 .finish()
28 }
29}
30
31impl Client {
32 pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
34 Self { provider }
35 }
36
37 pub fn builder() -> ClientBuilder {
39 ClientBuilder::default()
40 }
41
42 pub fn from_env() -> Result<Self> {
45 let provider_str = std::env::var("COGNIS_PROVIDER")
46 .map_err(|_| CognisError::Configuration("COGNIS_PROVIDER not set".into()))?;
47 let provider = Provider::from_str(&provider_str)?;
48 let mut b = Self::builder().provider(provider);
49
50 let key = std::env::var(format!(
51 "COGNIS_{}_API_KEY",
52 provider.to_string().to_uppercase()
53 ))
54 .or_else(|_| std::env::var("COGNIS_API_KEY"))
55 .ok();
56 if let Some(k) = key {
57 b = b.api_key(k);
58 }
59
60 let url = std::env::var(format!(
61 "COGNIS_{}_BASE_URL",
62 provider.to_string().to_uppercase()
63 ))
64 .or_else(|_| std::env::var("COGNIS_BASE_URL"))
65 .ok();
66 if let Some(u) = url {
67 b = b.base_url(u);
68 }
69
70 let model = std::env::var(format!(
71 "COGNIS_{}_MODEL",
72 provider.to_string().to_uppercase()
73 ))
74 .or_else(|_| std::env::var("COGNIS_MODEL"))
75 .ok();
76 if let Some(m) = model {
77 b = b.model(m);
78 }
79
80 b.build()
81 }
82
83 pub async fn invoke(&self, messages: Vec<Message>) -> Result<Message> {
85 Ok(self
86 .provider
87 .chat_completion(messages, ChatOptions::default())
88 .await?
89 .message)
90 }
91
92 pub async fn stream(&self, messages: Vec<Message>) -> Result<RunnableStream<StreamChunk>> {
94 self.provider
95 .chat_completion_stream(messages, ChatOptions::default())
96 .await
97 }
98
99 pub async fn invoke_with_tools(
101 &self,
102 messages: Vec<Message>,
103 tools: &[Arc<dyn crate::tools::Tool>],
104 ) -> Result<Message> {
105 let defs: Vec<ToolDefinition> = tools
106 .iter()
107 .map(|t| ToolDefinition::from_tool(t.as_ref()))
108 .collect();
109 Ok(self
110 .provider
111 .chat_completion_with_tools(messages, defs, ChatOptions::default())
112 .await?
113 .message)
114 }
115
116 pub async fn chat(&self, messages: Vec<Message>, opts: ChatOptions) -> Result<ChatResponse> {
118 self.provider.chat_completion(messages, opts).await
119 }
120
121 pub fn provider(&self) -> &Arc<dyn LLMProvider> {
123 &self.provider
124 }
125}
126
127#[async_trait]
128impl Runnable<Vec<Message>, Message> for Client {
129 async fn invoke(&self, input: Vec<Message>, _: RunnableConfig) -> Result<Message> {
130 Client::invoke(self, input).await
131 }
132 fn name(&self) -> &str {
133 "Client"
134 }
135}
136
137#[derive(Default)]
139pub struct ClientBuilder {
140 provider: Option<Provider>,
141 api_key: Option<String>,
142 base_url: Option<String>,
143 model: Option<String>,
144 timeout_secs: Option<u64>,
145 organization: Option<String>,
146 azure_endpoint: Option<String>,
147 azure_deployment: Option<String>,
148 azure_api_version: Option<String>,
149}
150
151impl ClientBuilder {
152 pub fn provider(mut self, p: Provider) -> Self {
154 self.provider = Some(p);
155 self
156 }
157 pub fn api_key(mut self, k: impl Into<String>) -> Self {
159 self.api_key = Some(k.into());
160 self
161 }
162 pub fn base_url(mut self, u: impl Into<String>) -> Self {
164 self.base_url = Some(u.into());
165 self
166 }
167 pub fn model(mut self, m: impl Into<String>) -> Self {
169 self.model = Some(m.into());
170 self
171 }
172 pub fn timeout_secs(mut self, s: u64) -> Self {
174 self.timeout_secs = Some(s);
175 self
176 }
177 pub fn organization(mut self, o: impl Into<String>) -> Self {
179 self.organization = Some(o.into());
180 self
181 }
182 pub fn azure_endpoint(mut self, e: impl Into<String>) -> Self {
184 self.azure_endpoint = Some(e.into());
185 self
186 }
187 pub fn azure_deployment(mut self, d: impl Into<String>) -> Self {
189 self.azure_deployment = Some(d.into());
190 self
191 }
192 pub fn azure_api_version(mut self, v: impl Into<String>) -> Self {
194 self.azure_api_version = Some(v.into());
195 self
196 }
197 pub fn build(self) -> Result<Client> {
199 let provider = self
200 .provider
201 .ok_or_else(|| CognisError::Configuration("Client: provider required".into()))?;
202 let arc_provider: Arc<dyn LLMProvider> = match provider {
203 #[cfg(feature = "openai")]
204 Provider::OpenAI => {
205 use crate::provider::openai::OpenAIBuilder;
206 let mut b = OpenAIBuilder::default();
207 if let Some(k) = self.api_key {
208 b = b.api_key(k);
209 }
210 if let Some(u) = self.base_url {
211 b = b.base_url(u);
212 }
213 if let Some(m) = self.model {
214 b = b.model(m);
215 }
216 if let Some(t) = self.timeout_secs {
217 b = b.timeout_secs(t);
218 }
219 if let Some(o) = self.organization {
220 b = b.organization(o);
221 }
222 Arc::new(b.build()?)
223 }
224 #[cfg(feature = "openai")]
225 Provider::OpenRouter => {
226 use crate::provider::openai::OpenAIBuilder;
227 let mut b = OpenAIBuilder::default()
228 .base_url(Provider::OpenRouter.default_base_url())
229 .model(Provider::OpenRouter.default_model());
230 if let Some(k) = self.api_key {
231 b = b.api_key(k);
232 }
233 if let Some(u) = self.base_url {
234 b = b.base_url(u);
235 }
236 if let Some(m) = self.model {
237 b = b.model(m);
238 }
239 if let Some(t) = self.timeout_secs {
240 b = b.timeout_secs(t);
241 }
242 Arc::new(b.build()?)
243 }
244 #[cfg(feature = "ollama")]
245 Provider::Ollama => {
246 use crate::provider::ollama::OllamaBuilder;
247 let mut b = OllamaBuilder::default();
248 if let Some(u) = self.base_url {
249 b = b.base_url(u);
250 }
251 if let Some(m) = self.model {
252 b = b.model(m);
253 }
254 if let Some(t) = self.timeout_secs {
255 b = b.timeout_secs(t);
256 }
257 Arc::new(b.build()?)
258 }
259 #[cfg(feature = "anthropic")]
260 Provider::Anthropic => {
261 use crate::provider::anthropic::AnthropicBuilder;
262 let mut b = AnthropicBuilder::default();
263 if let Some(k) = self.api_key {
264 b = b.api_key(k);
265 }
266 if let Some(u) = self.base_url {
267 b = b.base_url(u);
268 }
269 if let Some(m) = self.model {
270 b = b.model(m);
271 }
272 if let Some(t) = self.timeout_secs {
273 b = b.timeout_secs(t);
274 }
275 Arc::new(b.build()?)
276 }
277 #[cfg(feature = "google")]
278 Provider::Google => {
279 use crate::provider::google::GoogleBuilder;
280 let mut b = GoogleBuilder::default();
281 if let Some(k) = self.api_key {
282 b = b.api_key(k);
283 }
284 if let Some(u) = self.base_url {
285 b = b.base_url(u);
286 }
287 if let Some(m) = self.model {
288 b = b.model(m);
289 }
290 if let Some(t) = self.timeout_secs {
291 b = b.timeout_secs(t);
292 }
293 Arc::new(b.build()?)
294 }
295 #[cfg(feature = "azure")]
296 Provider::Azure => {
297 use crate::provider::azure::AzureBuilder;
298 let mut b = AzureBuilder::default();
299 if let Some(e) = self.azure_endpoint {
300 b = b.endpoint(e);
301 }
302 if let Some(d) = self.azure_deployment {
303 b = b.deployment(d);
304 }
305 if let Some(v) = self.azure_api_version {
306 b = b.api_version(v);
307 }
308 if let Some(k) = self.api_key {
309 b = b.api_key(k);
310 }
311 if let Some(t) = self.timeout_secs {
312 b = b.timeout_secs(t);
313 }
314 Arc::new(b.build()?)
315 }
316 #[allow(unreachable_patterns)]
317 other => {
318 return Err(CognisError::Configuration(format!(
319 "provider `{other}` not compiled in (enable the matching feature flag)"
320 )))
321 }
322 };
323 Ok(Client {
324 provider: arc_provider,
325 })
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn missing_provider_errors() {
        let err = Client::builder()
            .build()
            .expect_err("a builder without a provider must not build");
        let message = err.to_string();
        assert!(
            message.contains("provider required"),
            "unexpected error: {message}"
        );
    }

    #[cfg(feature = "openai")]
    #[test]
    fn openai_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::OpenAI)
            .api_key("sk-test")
            .model("gpt-4o")
            .build()
            .expect("OpenAI client should build");
        assert_eq!(client.provider().name(), "openai");
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn ollama_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Ollama)
            .model("llama3.2")
            .build()
            .expect("Ollama client should build");
        assert_eq!(client.provider().name(), "ollama");
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn anthropic_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Anthropic)
            .api_key("sk-ant-test")
            .build()
            .expect("Anthropic client should build");
        assert_eq!(client.provider().name(), "anthropic");
    }

    #[cfg(feature = "google")]
    #[test]
    fn google_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Google)
            .api_key("AIza-test")
            .build()
            .expect("Google client should build");
        assert_eq!(client.provider().name(), "google");
    }

    #[cfg(feature = "azure")]
    #[test]
    fn azure_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Azure)
            .azure_endpoint("https://r.openai.azure.com/")
            .azure_deployment("gpt-4o")
            .api_key("k")
            .build()
            .expect("Azure client should build");
        assert_eq!(client.provider().name(), "azure");
    }
}