crabtalk_model/
provider.rs1use crate::{config::ProviderDef, convert};
7use anyhow::Result;
8use async_stream::try_stream;
9use crabllm_core::ApiError;
10use crabllm_provider::Provider as CtProvider;
11use futures_core::Stream;
12use futures_util::StreamExt;
13use rand::Rng;
14use std::time::Duration;
15use wcore::model::{Model, Response, StreamChunk};
16
/// Chat-completion model backed by a `crabllm_provider` provider.
#[derive(Clone)]
pub struct Provider {
    /// Underlying crabllm provider that performs the requests.
    inner: CtProvider,
    /// HTTP client handed to every provider call.
    client: reqwest::Client,
    /// Model identifier reported by `model_name` and `active_model`.
    model: String,
    /// Extra attempts allowed for transient failures in non-streaming sends.
    max_retries: u32,
    /// Time limit for a non-streaming attempt or for opening a stream;
    /// `send_with_retry` treats zero as "no timeout".
    timeout: Duration,
}
26
27impl Provider {
28 pub fn model_name(&self) -> &String {
30 &self.model
31 }
32}
33
/// Reduces a configured endpoint URL to its API base.
///
/// Trailing slashes are removed first. Then at most one known endpoint suffix
/// (`/chat/completions`, `/messages`, `/embeddings`, checked in that order) is
/// stripped. For example, `https://host/v1/chat/completions/` becomes `https://host/v1`.
fn normalize_base_url(url: &str) -> String {
    const ENDPOINT_SUFFIXES: [&str; 3] = ["/chat/completions", "/messages", "/embeddings"];

    let trimmed = url.trim_end_matches('/');
    ENDPOINT_SUFFIXES
        .iter()
        .find_map(|&suffix| trimmed.strip_suffix(suffix))
        .unwrap_or(trimmed)
        .to_owned()
}
44
45pub fn build_provider(def: &ProviderDef, model: &str, client: reqwest::Client) -> Result<Provider> {
47 let mut config = def.clone();
48 config.kind = config.effective_kind();
49 let mut inner = CtProvider::from(&config);
50
51 if let CtProvider::OpenAiCompat {
53 ref mut base_url, ..
54 } = inner
55 {
56 *base_url = normalize_base_url(base_url);
57 }
58
59 Ok(Provider {
60 inner,
61 client,
62 model: model.to_owned(),
63 max_retries: def.max_retries.unwrap_or(2),
64 timeout: Duration::from_secs(def.timeout.unwrap_or(30)),
65 })
66}
67
68impl Model for Provider {
69 async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
70 let mut ct_req = convert::to_ct_request(request);
71 ct_req.stream = Some(false);
72 send_with_retry(
73 &self.inner,
74 &self.client,
75 &ct_req,
76 self.max_retries,
77 self.timeout,
78 )
79 .await
80 }
81
82 fn stream(
83 &self,
84 request: wcore::model::Request,
85 ) -> impl Stream<Item = Result<StreamChunk>> + Send {
86 let inner = self.inner.clone();
87 let client = self.client.clone();
88 let timeout = self.timeout;
89 try_stream! {
90 let mut ct_req = convert::to_ct_request(&request);
91 ct_req.stream = Some(true);
92
93 let boxed = tokio::time::timeout(timeout, inner.chat_completion_stream(&client, &ct_req))
94 .await
95 .map_err(|_| anyhow::anyhow!("stream connection timed out"))?
96 .map_err(format_provider_error)?;
97
98 let mut stream = std::pin::pin!(boxed);
99 while let Some(chunk) = stream.next().await {
100 let ct_chunk = chunk.map_err(format_provider_error)?;
101 yield convert::from_ct_chunk(ct_chunk);
102 }
103 }
104 }
105
106 fn context_limit(&self, model: &str) -> usize {
107 wcore::model::default_context_limit(model)
108 }
109
110 fn active_model(&self) -> String {
111 self.model.clone()
112 }
113}
114
115async fn send_with_retry(
117 provider: &CtProvider,
118 client: &reqwest::Client,
119 request: &crabllm_core::ChatCompletionRequest,
120 max_retries: u32,
121 timeout: Duration,
122) -> Result<Response> {
123 let mut backoff = Duration::from_millis(100);
124 let mut last_err = None;
125
126 for _ in 0..=max_retries {
127 let result = if timeout.is_zero() {
128 provider.chat_completion(client, request).await
129 } else {
130 tokio::time::timeout(timeout, provider.chat_completion(client, request))
131 .await
132 .map_err(|_| crabllm_core::Error::Timeout)?
133 };
134
135 match result {
136 Ok(resp) => return Ok(convert::from_ct_response(resp)),
137 Err(e) if e.is_transient() => {
138 last_err = Some(e);
139 let jitter = jittered(backoff);
140 tokio::time::sleep(jitter).await;
141 backoff *= 2;
142 }
143 Err(e) => return Err(format_provider_error(e)),
144 }
145 }
146
147 Err(format_provider_error(last_err.unwrap()))
148}
149
150fn jittered(backoff: Duration) -> Duration {
152 let lo = backoff.as_millis() as u64 / 2;
153 let hi = backoff.as_millis() as u64;
154 if lo >= hi {
155 return backoff;
156 }
157 Duration::from_millis(rand::rng().random_range(lo..=hi))
158}
159
160fn format_provider_error(e: crabllm_core::Error) -> anyhow::Error {
165 match e {
166 crabllm_core::Error::Provider { status, body } => {
167 let msg = serde_json::from_str::<ApiError>(&body)
168 .map(|api_err| api_err.error.message)
169 .unwrap_or_else(|_| truncate(&body, 200));
170 anyhow::anyhow!("provider error (HTTP {status}): {msg}")
171 }
172 other => anyhow::anyhow!("{other}"),
173 }
174}
175
/// Shortens `s` to at most `max` characters (not bytes) and appends `...` when
/// anything was removed.
///
/// The cut is made at a char boundary, so multi-byte UTF-8 text is never split.
fn truncate(s: &str, max: usize) -> String {
    let Some((cut, _)) = s.char_indices().nth(max) else {
        return s.to_string();
    };
    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}