1pub mod config;
25pub mod error;
26pub mod models;
27pub mod rate_limit;
28pub mod request;
29pub mod response;
30pub mod version;
31
32#[cfg(feature = "batch")]
33pub mod batch;
34#[cfg(feature = "compression")]
35pub mod compression;
36#[cfg(feature = "batch")]
37pub mod files;
38
39mod sse;
40mod stream;
41
42use std::collections::VecDeque;
43use std::sync::{Arc, Mutex};
44
45use agentkit_core::TurnCancellation;
46use agentkit_http::{BodyStream, Http, HttpError, HttpRequestBuilder};
47use agentkit_loop::{
48 LoopError, ModelAdapter, ModelSession, ModelTurn, ModelTurnEvent, SessionConfig, TurnRequest,
49};
50use async_trait::async_trait;
51use futures_util::StreamExt;
52use futures_util::future::{Either, select};
53
54pub use crate::config::{
55 CerebrasConfig, DEFAULT_BASE_URL, DEFAULT_VERSION_PATCH, OutputFormat, PartKindName,
56 ReasoningConfig, ReasoningEffort, ReasoningFormat, ToolChoice,
57};
58pub use crate::error::{BuildError, CerebrasError, ResponseError};
59pub use crate::models::{ModelObject, ModelsClient};
60pub use crate::rate_limit::RateLimitSnapshot;
61
62#[cfg(feature = "predicted-outputs")]
63pub use crate::config::Prediction;
64#[cfg(feature = "compression")]
65pub use crate::config::{CompressionConfig, RequestEncoding};
66#[cfg(feature = "service-tiers")]
67pub use crate::config::{QueueThreshold, ServiceTier};
68
69#[cfg(feature = "batch")]
70pub use crate::batch::{
71 BatchClient, BatchItem, BatchJob, BatchOutcome, BatchRequestCounts, BatchStatus, ChatOverrides,
72};
73#[cfg(feature = "batch")]
74pub use crate::files::{FileObject, FilePurpose, FilesClient};
75
76use crate::stream::{EventTranslator, SseDecoder};
77
/// Model adapter that talks to the Cerebras chat-completions endpoint.
///
/// The configuration and the rate-limit slot are both held behind `Arc`, so
/// clones of the adapter (and the sessions it starts) share them.
#[derive(Clone)]
pub struct CerebrasAdapter {
    /// HTTP client used for requests issued through this adapter.
    client: Http,
    /// Configuration; both constructors validate it before storing it here.
    config: Arc<CerebrasConfig>,
    /// Most recent rate-limit snapshot, written by sessions after a response
    /// arrives. `None` until the first response has been received.
    last_rate_limit: Arc<Mutex<Option<RateLimitSnapshot>>>,
}
86
87impl CerebrasAdapter {
88 pub fn new(config: CerebrasConfig) -> Result<Self, CerebrasError> {
91 config.validate()?;
92 let client = reqwest::Client::builder()
93 .build()
94 .map(Http::new)
95 .map_err(|error| CerebrasError::Http(HttpError::request(error)))?;
96 Ok(Self {
97 client,
98 config: Arc::new(config),
99 last_rate_limit: Arc::new(Mutex::new(None)),
100 })
101 }
102
103 pub fn with_client(config: CerebrasConfig, client: Http) -> Result<Self, CerebrasError> {
105 config.validate()?;
106 Ok(Self {
107 client,
108 config: Arc::new(config),
109 last_rate_limit: Arc::new(Mutex::new(None)),
110 })
111 }
112
113 pub fn last_rate_limit(&self) -> Option<RateLimitSnapshot> {
115 self.last_rate_limit.lock().ok()?.clone()
116 }
117
118 pub fn models(&self) -> ModelsClient<'_> {
120 ModelsClient::new(&self.client, self.config.clone())
121 }
122
123 #[cfg(feature = "batch")]
125 pub fn batches(&self) -> BatchClient<'_> {
126 BatchClient::new(&self.client, self.config.clone())
127 }
128
129 #[cfg(feature = "batch")]
131 pub fn files(&self) -> FilesClient<'_> {
132 FilesClient::new(&self.client, self.config.clone())
133 }
134}
135
/// A session started from a [`CerebrasAdapter`]; each turn issues one
/// chat-completions request.
pub struct CerebrasSession {
    /// HTTP client cloned from the adapter that started this session.
    client: Http,
    /// Configuration shared with the adapter.
    config: Arc<CerebrasConfig>,
    /// The adapter's rate-limit slot; updated from the headers of each response.
    rate_limit_slot: Arc<Mutex<Option<RateLimitSnapshot>>>,
    /// Session configuration passed to `start_session`. It is stored but not
    /// read anywhere in the code visible in this file.
    _session_config: SessionConfig,
}
143
/// A single model turn whose events are pulled one at a time via `next_event`.
pub struct CerebrasTurn {
    /// Buffered or streaming source of the turn's events.
    inner: TurnInner,
}
148
/// Where a turn's events come from.
enum TurnInner {
    /// The whole response was read up front (non-streaming mode); events are
    /// handed out from the front of the queue.
    Buffered { events: VecDeque<ModelTurnEvent> },
    /// The response body is still being read (streaming mode); the state is
    /// boxed, keeping this variant pointer-sized.
    Streaming(Box<StreamingState>),
}
153
/// Mutable state carried between `next_event` calls of a streaming turn.
struct StreamingState {
    /// Response body, consumed chunk by chunk.
    body: BodyStream,
    /// Fed the text of each chunk; its output is passed on to `translator`.
    decoder: SseDecoder,
    /// Converts decoder output into `ModelTurnEvent`s and reports, through
    /// `is_done`, when no further events are expected.
    translator: EventTranslator,
    /// Events already produced but not yet returned to the caller.
    pending: VecDeque<ModelTurnEvent>,
    /// Set once the body stream has yielded `None`.
    eof: bool,
}
161
162#[async_trait]
163impl ModelAdapter for CerebrasAdapter {
164 type Session = CerebrasSession;
165
166 async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
167 Ok(CerebrasSession {
168 client: self.client.clone(),
169 config: self.config.clone(),
170 rate_limit_slot: self.last_rate_limit.clone(),
171 _session_config: config,
172 })
173 }
174
175 fn provider_name(&self) -> Option<&str> {
176 Some("cerebras")
177 }
178}
179
180#[async_trait]
181impl ModelSession for CerebrasSession {
182 type Turn = CerebrasTurn;
183
184 async fn begin_turn(
185 &mut self,
186 turn_request: TurnRequest,
187 cancellation: Option<TurnCancellation>,
188 ) -> Result<CerebrasTurn, LoopError> {
189 let config = self.config.clone();
190 let rate_limit_slot = self.rate_limit_slot.clone();
191
192 let request_future = async move {
193 let built = request::build_chat_body(&config, &turn_request)
194 .map_err(|e| LoopError::Provider(e.to_string()))?;
195
196 let url = format!("{}/chat/completions", config.base_url);
197 let mut http = self.client.post(&url).bearer_auth(&config.api_key);
198
199 #[cfg(feature = "compression")]
200 let (body_bytes, content_type, content_encoding) = match &config.compression {
201 Some(cfg) => {
202 let encoded = crate::compression::encode_body(&built.body, cfg)
203 .map_err(LoopError::Provider)?;
204 (
205 bytes::Bytes::from(encoded.body),
206 encoded.content_type,
207 encoded.content_encoding,
208 )
209 }
210 None => (
211 bytes::Bytes::from(
212 serde_json::to_vec(&built.body)
213 .map_err(|e| LoopError::Provider(format!("json serialize: {e}")))?,
214 ),
215 "application/json",
216 None,
217 ),
218 };
219 #[cfg(not(feature = "compression"))]
220 let (body_bytes, content_type, content_encoding) = (
221 bytes::Bytes::from(
222 serde_json::to_vec(&built.body)
223 .map_err(|e| LoopError::Provider(format!("json serialize: {e}")))?,
224 ),
225 "application/json",
226 None::<&'static str>,
227 );
228
229 http = http.header("Content-Type", content_type);
230 if let Some(enc) = content_encoding {
231 http = http.header("Content-Encoding", enc);
232 }
233 if let Some(patch) = config.version_patch {
234 http = http.header(
235 crate::version::VERSION_PATCH_HEADER,
236 crate::version::format_version_patch(patch),
237 );
238 }
239 for (k, v) in &built.extra_headers {
240 http = http.header(*k, v.clone());
241 }
242 http = http.header(
243 "User-Agent",
244 concat!("agentkit-provider-cerebras/", env!("CARGO_PKG_VERSION")),
245 );
246 if config.streaming {
247 http = http.header("Accept", "text/event-stream");
248 }
249 for (k, v) in &config.extra_headers {
250 http = http.header(k.as_str(), v.as_str());
251 }
252 http = attach_body(http, body_bytes);
253
254 let response = http.send().await.map_err(|error| {
255 LoopError::Provider(format!("Cerebras request failed: {error}"))
256 })?;
257
258 {
259 let snap = RateLimitSnapshot::from_headers(response.headers());
260 if let Ok(mut slot) = rate_limit_slot.lock() {
261 *slot = Some(snap);
262 }
263 }
264
265 let status = response.status();
266 if !status.is_success() {
267 let body_text = response.text().await.unwrap_or_default();
268 return Err(LoopError::Provider(format!(
269 "Cerebras request failed with status {status}: {body_text}"
270 )));
271 }
272
273 if config.streaming {
274 Ok(CerebrasTurn {
275 inner: TurnInner::Streaming(Box::new(StreamingState {
276 body: response.bytes_stream(),
277 decoder: SseDecoder::new(),
278 translator: EventTranslator::new(),
279 pending: VecDeque::new(),
280 eof: false,
281 })),
282 })
283 } else {
284 let body_text = response.text().await.map_err(|error| {
285 LoopError::Provider(format!("failed to read Cerebras response body: {error}"))
286 })?;
287
288 let events = response::build_turn_from_response(&body_text)
289 .map_err(|e| LoopError::Provider(e.to_string()))?;
290 Ok(CerebrasTurn {
291 inner: TurnInner::Buffered { events },
292 })
293 }
294 };
295
296 if let Some(cancellation) = cancellation {
297 futures_util::pin_mut!(request_future);
298 let cancelled = cancellation.cancelled();
299 futures_util::pin_mut!(cancelled);
300 match select(request_future, cancelled).await {
301 Either::Left((result, _)) => result,
302 Either::Right((_, _)) => Err(LoopError::Cancelled),
303 }
304 } else {
305 request_future.await
306 }
307 }
308
309 fn model_name(&self) -> Option<&str> {
310 Some(&self.config.model)
311 }
312}
313
/// Sets `body` as the payload of `builder` and returns the updated builder.
///
/// A thin wrapper over `HttpRequestBuilder::body`, giving the call a fixed
/// `bytes::Bytes` argument type at its single call site in `begin_turn`.
fn attach_body(builder: HttpRequestBuilder, body: bytes::Bytes) -> HttpRequestBuilder {
    builder.body(body)
}
317
318#[async_trait]
319impl ModelTurn for CerebrasTurn {
320 async fn next_event(
321 &mut self,
322 cancellation: Option<TurnCancellation>,
323 ) -> Result<Option<ModelTurnEvent>, LoopError> {
324 if cancellation
325 .as_ref()
326 .is_some_and(TurnCancellation::is_cancelled)
327 {
328 return Err(LoopError::Cancelled);
329 }
330 match &mut self.inner {
331 TurnInner::Buffered { events } => Ok(events.pop_front()),
332 TurnInner::Streaming(state) => {
333 let StreamingState {
334 body,
335 decoder,
336 translator,
337 pending,
338 eof,
339 } = state.as_mut();
340 next_streaming_event(body, decoder, translator, pending, eof, cancellation).await
341 }
342 }
343 }
344}
345
346async fn next_streaming_event(
347 body: &mut BodyStream,
348 decoder: &mut SseDecoder,
349 translator: &mut EventTranslator,
350 pending: &mut VecDeque<ModelTurnEvent>,
351 eof: &mut bool,
352 cancellation: Option<TurnCancellation>,
353) -> Result<Option<ModelTurnEvent>, LoopError> {
354 loop {
355 if let Some(event) = pending.pop_front() {
356 return Ok(Some(event));
357 }
358 if *eof || translator.is_done() {
359 return Ok(None);
360 }
361
362 let chunk = if let Some(cancellation) = cancellation.as_ref() {
363 let next = body.next();
364 futures_util::pin_mut!(next);
365 let cancelled = cancellation.cancelled();
366 futures_util::pin_mut!(cancelled);
367 match select(next, cancelled).await {
368 Either::Left((chunk, _)) => chunk,
369 Either::Right((_, _)) => return Err(LoopError::Cancelled),
370 }
371 } else {
372 body.next().await
373 };
374
375 match chunk {
376 Some(Ok(bytes)) => {
377 let text = std::str::from_utf8(&bytes).map_err(|e| {
378 LoopError::Provider(format!("invalid UTF-8 in Cerebras stream: {e}"))
379 })?;
380 for sse in decoder.feed(text) {
381 match translator.handle(&sse) {
382 Ok(produced) => {
383 for ev in produced {
384 pending.push_back(ev);
385 }
386 }
387 Err(e) => return Err(LoopError::Provider(e.to_string())),
388 }
389 }
390 }
391 Some(Err(e)) => {
392 return Err(LoopError::Provider(format!(
393 "Cerebras stream body error: {e}"
394 )));
395 }
396 None => {
397 *eof = true;
398 }
399 }
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use agentkit_core::{CancellationController, FinishReason};
    use agentkit_http::HttpError;
    use bytes::Bytes;
    use futures_util::stream;

    /// Builds a streaming turn whose body yields `chunks` one after another.
    fn streaming_turn_from(chunks: Vec<&'static str>) -> CerebrasTurn {
        let frames = chunks
            .into_iter()
            .map(|c| Ok::<_, HttpError>(Bytes::from_static(c.as_bytes())));
        let body: BodyStream = Box::pin(stream::iter(frames));
        let state = StreamingState {
            body,
            decoder: SseDecoder::new(),
            translator: EventTranslator::new(),
            pending: VecDeque::new(),
            eof: false,
        };
        CerebrasTurn {
            inner: TurnInner::Streaming(Box::new(state)),
        }
    }

    /// A content delta followed by a finish chunk and `[DONE]` must surface a
    /// `Finished` event with `FinishReason::Completed`.
    #[tokio::test(flavor = "current_thread")]
    async fn streaming_turn_drains_to_finished() {
        let mut turn = streaming_turn_from(vec![
            "data: {\"id\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"done\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1}}\n\n",
            "data: [DONE]\n\n",
        ]);

        let mut saw_finished = false;
        loop {
            let Some(event) = turn.next_event(None).await.expect("next_event") else {
                break;
            };
            if let ModelTurnEvent::Finished(result) = event {
                assert_eq!(result.finish_reason, FinishReason::Completed);
                saw_finished = true;
            }
        }
        assert!(saw_finished);
    }

    /// A cancellation that fired before the call must win over available data.
    #[tokio::test(flavor = "current_thread")]
    async fn streaming_turn_respects_pre_fired_cancellation() {
        let mut turn = streaming_turn_from(vec!["data: {\"id\":\"m\",\"choices\":[]}\n\n"]);

        let controller = CancellationController::new();
        let checkpoint = TurnCancellation::new(controller.handle());
        controller.interrupt();

        let outcome = turn.next_event(Some(checkpoint)).await;
        assert!(matches!(outcome, Err(LoopError::Cancelled)));
    }
}