crabtalk_model/
provider.rs1use crate::{config::ProviderDef, convert};
7use anyhow::Result;
8use async_stream::try_stream;
9use crabllm_provider::Provider as CtProvider;
10use futures_core::Stream;
11use futures_util::StreamExt;
12use rand::Rng;
13use std::time::Duration;
14use wcore::model::{Model, Response, StreamChunk};
15
/// A configured chat-completion backend: the underlying crabllm provider
/// plus the HTTP client and the per-request retry/timeout policy applied
/// by [`Model::send`].
#[derive(Clone)]
pub struct Provider {
    // Concrete backend implementation, selected from the provider definition.
    inner: CtProvider,
    // HTTP client shared across requests (reqwest clients are cheap to clone).
    client: reqwest::Client,
    // Model identifier reported by `active_model` / `model_name`.
    model: String,
    // Additional attempts after the first failure, for transient errors only.
    max_retries: u32,
    // Per-attempt timeout; a zero duration disables the timeout entirely.
    timeout: Duration,
}
25
26impl Provider {
27 pub fn model_name(&self) -> &String {
29 &self.model
30 }
31}
32
/// Normalizes a provider base URL: drops trailing slashes and, when the URL
/// was given as a full endpoint, removes the known endpoint suffix so only
/// the base remains.
fn normalize_base_url(url: &str) -> String {
    let trimmed = url.trim_end_matches('/');
    ["/chat/completions", "/messages", "/embeddings"]
        .iter()
        .find_map(|suffix| trimmed.strip_suffix(suffix))
        .unwrap_or(trimmed)
        .to_string()
}
43
44pub fn build_provider(def: &ProviderDef, model: &str, client: reqwest::Client) -> Result<Provider> {
46 let mut config = def.clone();
47 config.kind = config.effective_kind();
48 let mut inner = CtProvider::from(&config);
49
50 if let CtProvider::OpenAiCompat {
52 ref mut base_url, ..
53 } = inner
54 {
55 *base_url = normalize_base_url(base_url);
56 }
57
58 Ok(Provider {
59 inner,
60 client,
61 model: model.to_owned(),
62 max_retries: def.max_retries.unwrap_or(2),
63 timeout: Duration::from_secs(def.timeout.unwrap_or(30)),
64 })
65}
66
impl Model for Provider {
    /// Sends a single non-streaming chat request, retrying transient
    /// failures per the configured `max_retries`/`timeout` policy.
    async fn send(&self, request: &wcore::model::Request) -> Result<Response> {
        let mut ct_req = convert::to_ct_request(request);
        // Force non-streaming mode regardless of what the caller set.
        ct_req.stream = Some(false);
        send_with_retry(
            &self.inner,
            &self.client,
            &ct_req,
            self.max_retries,
            self.timeout,
        )
        .await
    }

    /// Opens a streaming chat completion and yields converted chunks.
    ///
    /// Unlike `send`, there is no retry here: a failure mid-stream cannot be
    /// transparently resumed. The timeout covers only establishing the
    /// stream, not the lifetime of the stream itself.
    fn stream(
        &self,
        request: wcore::model::Request,
    ) -> impl Stream<Item = Result<StreamChunk>> + Send {
        // Clone what the generator needs so the returned stream is 'static
        // and does not borrow `self`.
        let inner = self.inner.clone();
        let client = self.client.clone();
        let timeout = self.timeout;
        try_stream! {
            let mut ct_req = convert::to_ct_request(&request);
            ct_req.stream = Some(true);

            // Outer Err: connection attempt exceeded `timeout`.
            // Inner Err: provider rejected the request.
            let boxed = tokio::time::timeout(timeout, inner.chat_completion_stream(&client, &ct_req))
                .await
                .map_err(|_| anyhow::anyhow!("stream connection timed out"))?
                .map_err(|e| anyhow::anyhow!("{e}"))?;

            // Pin on the stack so we can poll the boxed stream.
            let mut stream = std::pin::pin!(boxed);
            while let Some(chunk) = stream.next().await {
                let ct_chunk = chunk.map_err(|e| anyhow::anyhow!("{e}"))?;
                yield convert::from_ct_chunk(ct_chunk);
            }
        }
    }

    /// Context window size for `model`, from wcore's built-in defaults.
    fn context_limit(&self, model: &str) -> usize {
        wcore::model::default_context_limit(model)
    }

    /// The model identifier this provider was configured with.
    fn active_model(&self) -> String {
        self.model.clone()
    }
}
113
114async fn send_with_retry(
116 provider: &CtProvider,
117 client: &reqwest::Client,
118 request: &crabllm_core::ChatCompletionRequest,
119 max_retries: u32,
120 timeout: Duration,
121) -> Result<Response> {
122 let mut backoff = Duration::from_millis(100);
123 let mut last_err = None;
124
125 for _ in 0..=max_retries {
126 let result = if timeout.is_zero() {
127 provider.chat_completion(client, request).await
128 } else {
129 tokio::time::timeout(timeout, provider.chat_completion(client, request))
130 .await
131 .map_err(|_| crabllm_core::Error::Timeout)?
132 };
133
134 match result {
135 Ok(resp) => return Ok(convert::from_ct_response(resp)),
136 Err(e) if e.is_transient() => {
137 last_err = Some(e);
138 let jitter = jittered(backoff);
139 tokio::time::sleep(jitter).await;
140 backoff *= 2;
141 }
142 Err(e) => return Err(anyhow::anyhow!("{e}")),
143 }
144 }
145
146 Err(anyhow::anyhow!("{}", last_err.unwrap()))
147}
148
149fn jittered(backoff: Duration) -> Duration {
151 let lo = backoff.as_millis() as u64 / 2;
152 let hi = backoff.as_millis() as u64;
153 if lo >= hi {
154 return backoff;
155 }
156 Duration::from_millis(rand::rng().random_range(lo..=hi))
157}