Skip to main content

agentkit_provider_cerebras/
lib.rs

1//! Cerebras Inference API adapter for the agentkit agent loop.
2//!
3//! This crate implements the agentkit [`ModelAdapter`] directly against
4//! Cerebras' `/v1/chat/completions` endpoint.
5//!
6//! Streaming is on by default. Toggle via [`CerebrasConfig::with_streaming`].
7//!
8//! # Quick start
9//!
10//! ```rust,ignore
11//! use agentkit_loop::{Agent, SessionConfig};
12//! use agentkit_provider_cerebras::{CerebrasAdapter, CerebrasConfig};
13//!
14//! #[tokio::main]
15//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//!     let config = CerebrasConfig::from_env()?;
17//!     let adapter = CerebrasAdapter::new(config)?;
18//!     let agent = Agent::builder().model(adapter).build()?;
19//!     let _driver = agent.start(SessionConfig::new("demo")).await?;
20//!     Ok(())
21//! }
22//! ```
23
24pub mod config;
25pub mod error;
26pub mod models;
27pub mod rate_limit;
28pub mod request;
29pub mod response;
30pub mod version;
31
32#[cfg(feature = "batch")]
33pub mod batch;
34#[cfg(feature = "compression")]
35pub mod compression;
36#[cfg(feature = "batch")]
37pub mod files;
38
39mod sse;
40mod stream;
41
42use std::collections::VecDeque;
43use std::sync::{Arc, Mutex};
44
45use agentkit_core::TurnCancellation;
46use agentkit_http::{BodyStream, Http, HttpError, HttpRequestBuilder};
47use agentkit_loop::{
48    LoopError, ModelAdapter, ModelSession, ModelTurn, ModelTurnEvent, SessionConfig, TurnRequest,
49};
50use async_trait::async_trait;
51use futures_util::StreamExt;
52use futures_util::future::{Either, select};
53
54pub use crate::config::{
55    CerebrasConfig, DEFAULT_BASE_URL, DEFAULT_VERSION_PATCH, OutputFormat, PartKindName,
56    ReasoningConfig, ReasoningEffort, ReasoningFormat, ToolChoice,
57};
58pub use crate::error::{BuildError, CerebrasError, ResponseError};
59pub use crate::models::{ModelObject, ModelsClient};
60pub use crate::rate_limit::RateLimitSnapshot;
61
62#[cfg(feature = "predicted-outputs")]
63pub use crate::config::Prediction;
64#[cfg(feature = "compression")]
65pub use crate::config::{CompressionConfig, RequestEncoding};
66#[cfg(feature = "service-tiers")]
67pub use crate::config::{QueueThreshold, ServiceTier};
68
69#[cfg(feature = "batch")]
70pub use crate::batch::{
71    BatchClient, BatchItem, BatchJob, BatchOutcome, BatchRequestCounts, BatchStatus, ChatOverrides,
72};
73#[cfg(feature = "batch")]
74pub use crate::files::{FileObject, FilePurpose, FilesClient};
75
76use crate::stream::{EventTranslator, SseDecoder};
77
/// Model adapter that connects the agentkit agent loop to Cerebras'
/// `/v1/chat/completions` endpoint.
///
/// The adapter is `Clone`; every field is either a client handle or an `Arc`,
/// so clones share the same configuration and rate-limit slot.
#[derive(Clone)]
pub struct CerebrasAdapter {
    /// HTTP client that is cloned into every session started by this adapter.
    client: Http,
    /// Configuration shared with sessions and the typed sub-clients; the
    /// constructors in this file validate it before storing it.
    config: Arc<CerebrasConfig>,
    /// Most recent rate-limit snapshot. Sessions started from this adapter
    /// hold a clone of this `Arc` and overwrite the slot after each response.
    last_rate_limit: Arc<Mutex<Option<RateLimitSnapshot>>>,
}
86
87impl CerebrasAdapter {
88    /// Creates a new adapter from the given configuration, building a default
89    /// reqwest-backed HTTP client.
90    pub fn new(config: CerebrasConfig) -> Result<Self, CerebrasError> {
91        config.validate()?;
92        let client = reqwest::Client::builder()
93            .build()
94            .map(Http::new)
95            .map_err(|error| CerebrasError::Http(HttpError::request(error)))?;
96        Ok(Self {
97            client,
98            config: Arc::new(config),
99            last_rate_limit: Arc::new(Mutex::new(None)),
100        })
101    }
102
103    /// Creates a new adapter using a pre-configured [`Http`] client.
104    pub fn with_client(config: CerebrasConfig, client: Http) -> Result<Self, CerebrasError> {
105        config.validate()?;
106        Ok(Self {
107            client,
108            config: Arc::new(config),
109            last_rate_limit: Arc::new(Mutex::new(None)),
110        })
111    }
112
113    /// Reads the latest rate-limit snapshot, if any response has been received.
114    pub fn last_rate_limit(&self) -> Option<RateLimitSnapshot> {
115        self.last_rate_limit.lock().ok()?.clone()
116    }
117
118    /// Returns a typed client over `/v1/models`.
119    pub fn models(&self) -> ModelsClient<'_> {
120        ModelsClient::new(&self.client, self.config.clone())
121    }
122
123    /// Returns a typed client over the Batch API.
124    #[cfg(feature = "batch")]
125    pub fn batches(&self) -> BatchClient<'_> {
126        BatchClient::new(&self.client, self.config.clone())
127    }
128
129    /// Returns a typed client over the Files API.
130    #[cfg(feature = "batch")]
131    pub fn files(&self) -> FilesClient<'_> {
132        FilesClient::new(&self.client, self.config.clone())
133    }
134}
135
/// An active session against the Cerebras chat-completions endpoint.
///
/// Created by `CerebrasAdapter::start_session`; each call to `begin_turn`
/// issues one HTTP request using the fields below.
pub struct CerebrasSession {
    /// HTTP client cloned from the adapter that started this session.
    client: Http,
    /// Configuration shared with the originating adapter.
    config: Arc<CerebrasConfig>,
    /// Shared slot (same `Arc` as the adapter's `last_rate_limit`) that is
    /// overwritten with the snapshot taken from each response's headers.
    rate_limit_slot: Arc<Mutex<Option<RateLimitSnapshot>>>,
    /// Session configuration passed to `start_session`; stored but not read
    /// by the code in this file.
    _session_config: SessionConfig,
}
143
/// A single Cerebras chat-completions turn in progress.
///
/// Events are pulled from it one at a time through `ModelTurn::next_event`.
pub struct CerebrasTurn {
    /// Whether the turn replays a fully buffered response or decodes a live
    /// event stream.
    inner: TurnInner,
}
148
/// The two ways a turn can source its events, chosen in `begin_turn` from
/// `config.streaming`.
enum TurnInner {
    /// Non-streaming mode: the whole response body was read and translated up
    /// front; `events` is drained front to back.
    Buffered { events: VecDeque<ModelTurnEvent> },
    /// Streaming mode: events are decoded lazily from the response body. The
    /// state is boxed, which keeps this variant pointer-sized.
    Streaming(Box<StreamingState>),
}
153
/// Everything needed to keep decoding a streaming response between calls to
/// `next_event`.
struct StreamingState {
    /// Response body, yielded chunk by chunk.
    body: BodyStream,
    /// Turns the text of each chunk into SSE frames.
    decoder: SseDecoder,
    /// Turns SSE frames into `ModelTurnEvent`s and reports when the stream is
    /// logically done.
    translator: EventTranslator,
    /// Events already produced but not yet handed to the caller.
    pending: VecDeque<ModelTurnEvent>,
    /// Set once `body` has yielded `None`, so it is not polled again.
    eof: bool,
}
161
162#[async_trait]
163impl ModelAdapter for CerebrasAdapter {
164    type Session = CerebrasSession;
165
166    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
167        Ok(CerebrasSession {
168            client: self.client.clone(),
169            config: self.config.clone(),
170            rate_limit_slot: self.last_rate_limit.clone(),
171            _session_config: config,
172        })
173    }
174
175    fn provider_name(&self) -> Option<&str> {
176        Some("cerebras")
177    }
178}
179
180#[async_trait]
181impl ModelSession for CerebrasSession {
182    type Turn = CerebrasTurn;
183
184    async fn begin_turn(
185        &mut self,
186        turn_request: TurnRequest,
187        cancellation: Option<TurnCancellation>,
188    ) -> Result<CerebrasTurn, LoopError> {
189        let config = self.config.clone();
190        let rate_limit_slot = self.rate_limit_slot.clone();
191
192        let request_future = async move {
193            let built = request::build_chat_body(&config, &turn_request)
194                .map_err(|e| LoopError::Provider(e.to_string()))?;
195
196            let url = format!("{}/chat/completions", config.base_url);
197            let mut http = self.client.post(&url).bearer_auth(&config.api_key);
198
199            #[cfg(feature = "compression")]
200            let (body_bytes, content_type, content_encoding) = match &config.compression {
201                Some(cfg) => {
202                    let encoded = crate::compression::encode_body(&built.body, cfg)
203                        .map_err(LoopError::Provider)?;
204                    (
205                        bytes::Bytes::from(encoded.body),
206                        encoded.content_type,
207                        encoded.content_encoding,
208                    )
209                }
210                None => (
211                    bytes::Bytes::from(
212                        serde_json::to_vec(&built.body)
213                            .map_err(|e| LoopError::Provider(format!("json serialize: {e}")))?,
214                    ),
215                    "application/json",
216                    None,
217                ),
218            };
219            #[cfg(not(feature = "compression"))]
220            let (body_bytes, content_type, content_encoding) = (
221                bytes::Bytes::from(
222                    serde_json::to_vec(&built.body)
223                        .map_err(|e| LoopError::Provider(format!("json serialize: {e}")))?,
224                ),
225                "application/json",
226                None::<&'static str>,
227            );
228
229            http = http.header("Content-Type", content_type);
230            if let Some(enc) = content_encoding {
231                http = http.header("Content-Encoding", enc);
232            }
233            if let Some(patch) = config.version_patch {
234                http = http.header(
235                    crate::version::VERSION_PATCH_HEADER,
236                    crate::version::format_version_patch(patch),
237                );
238            }
239            for (k, v) in &built.extra_headers {
240                http = http.header(*k, v.clone());
241            }
242            http = http.header(
243                "User-Agent",
244                concat!("agentkit-provider-cerebras/", env!("CARGO_PKG_VERSION")),
245            );
246            if config.streaming {
247                http = http.header("Accept", "text/event-stream");
248            }
249            for (k, v) in &config.extra_headers {
250                http = http.header(k.as_str(), v.as_str());
251            }
252            http = attach_body(http, body_bytes);
253
254            let response = http.send().await.map_err(|error| {
255                LoopError::Provider(format!("Cerebras request failed: {error}"))
256            })?;
257
258            {
259                let snap = RateLimitSnapshot::from_headers(response.headers());
260                if let Ok(mut slot) = rate_limit_slot.lock() {
261                    *slot = Some(snap);
262                }
263            }
264
265            let status = response.status();
266            if !status.is_success() {
267                let body_text = response.text().await.unwrap_or_default();
268                return Err(LoopError::Provider(format!(
269                    "Cerebras request failed with status {status}: {body_text}"
270                )));
271            }
272
273            if config.streaming {
274                Ok(CerebrasTurn {
275                    inner: TurnInner::Streaming(Box::new(StreamingState {
276                        body: response.bytes_stream(),
277                        decoder: SseDecoder::new(),
278                        translator: EventTranslator::new(),
279                        pending: VecDeque::new(),
280                        eof: false,
281                    })),
282                })
283            } else {
284                let body_text = response.text().await.map_err(|error| {
285                    LoopError::Provider(format!("failed to read Cerebras response body: {error}"))
286                })?;
287
288                let events = response::build_turn_from_response(&body_text)
289                    .map_err(|e| LoopError::Provider(e.to_string()))?;
290                Ok(CerebrasTurn {
291                    inner: TurnInner::Buffered { events },
292                })
293            }
294        };
295
296        if let Some(cancellation) = cancellation {
297            futures_util::pin_mut!(request_future);
298            let cancelled = cancellation.cancelled();
299            futures_util::pin_mut!(cancelled);
300            match select(request_future, cancelled).await {
301                Either::Left((result, _)) => result,
302                Either::Right((_, _)) => Err(LoopError::Cancelled),
303            }
304        } else {
305            request_future.await
306        }
307    }
308
309    fn model_name(&self) -> Option<&str> {
310        Some(&self.config.model)
311    }
312}
313
/// Sets `body` as the payload of `builder` and returns the updated builder.
///
/// Thin wrapper over `HttpRequestBuilder::body`; `begin_turn` calls it after
/// all headers have been applied.
fn attach_body(builder: HttpRequestBuilder, body: bytes::Bytes) -> HttpRequestBuilder {
    builder.body(body)
}
317
318#[async_trait]
319impl ModelTurn for CerebrasTurn {
320    async fn next_event(
321        &mut self,
322        cancellation: Option<TurnCancellation>,
323    ) -> Result<Option<ModelTurnEvent>, LoopError> {
324        if cancellation
325            .as_ref()
326            .is_some_and(TurnCancellation::is_cancelled)
327        {
328            return Err(LoopError::Cancelled);
329        }
330        match &mut self.inner {
331            TurnInner::Buffered { events } => Ok(events.pop_front()),
332            TurnInner::Streaming(state) => {
333                let StreamingState {
334                    body,
335                    decoder,
336                    translator,
337                    pending,
338                    eof,
339                } = state.as_mut();
340                next_streaming_event(body, decoder, translator, pending, eof, cancellation).await
341            }
342        }
343    }
344}
345
346async fn next_streaming_event(
347    body: &mut BodyStream,
348    decoder: &mut SseDecoder,
349    translator: &mut EventTranslator,
350    pending: &mut VecDeque<ModelTurnEvent>,
351    eof: &mut bool,
352    cancellation: Option<TurnCancellation>,
353) -> Result<Option<ModelTurnEvent>, LoopError> {
354    loop {
355        if let Some(event) = pending.pop_front() {
356            return Ok(Some(event));
357        }
358        if *eof || translator.is_done() {
359            return Ok(None);
360        }
361
362        let chunk = if let Some(cancellation) = cancellation.as_ref() {
363            let next = body.next();
364            futures_util::pin_mut!(next);
365            let cancelled = cancellation.cancelled();
366            futures_util::pin_mut!(cancelled);
367            match select(next, cancelled).await {
368                Either::Left((chunk, _)) => chunk,
369                Either::Right((_, _)) => return Err(LoopError::Cancelled),
370            }
371        } else {
372            body.next().await
373        };
374
375        match chunk {
376            Some(Ok(bytes)) => {
377                let text = std::str::from_utf8(&bytes).map_err(|e| {
378                    LoopError::Provider(format!("invalid UTF-8 in Cerebras stream: {e}"))
379                })?;
380                for sse in decoder.feed(text) {
381                    match translator.handle(&sse) {
382                        Ok(produced) => {
383                            for ev in produced {
384                                pending.push_back(ev);
385                            }
386                        }
387                        Err(e) => return Err(LoopError::Provider(e.to_string())),
388                    }
389                }
390            }
391            Some(Err(e)) => {
392                return Err(LoopError::Provider(format!(
393                    "Cerebras stream body error: {e}"
394                )));
395            }
396            None => {
397                *eof = true;
398            }
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use agentkit_core::{CancellationController, FinishReason};
    use agentkit_http::HttpError;
    use bytes::Bytes;
    use futures_util::stream;

    /// Builds a streaming turn whose body yields `chunks` one by one and then
    /// ends, with a fresh decoder and translator.
    fn streaming_turn_from(chunks: Vec<&'static str>) -> CerebrasTurn {
        let frames = chunks
            .into_iter()
            .map(|chunk| Ok::<_, HttpError>(Bytes::from_static(chunk.as_bytes())));
        let body: BodyStream = Box::pin(stream::iter(frames));
        let state = StreamingState {
            body,
            decoder: SseDecoder::new(),
            translator: EventTranslator::new(),
            pending: VecDeque::new(),
            eof: false,
        };
        CerebrasTurn {
            inner: TurnInner::Streaming(Box::new(state)),
        }
    }

    /// A content delta, a finish chunk and `[DONE]` must drain to a
    /// `Finished` event with a completed finish reason.
    #[tokio::test(flavor = "current_thread")]
    async fn streaming_turn_drains_to_finished() {
        let mut turn = streaming_turn_from(vec![
            "data: {\"id\":\"m\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n",
            "data: {\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"done\"}],\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":1}}\n\n",
            "data: [DONE]\n\n",
        ]);

        let mut saw_finished = false;
        loop {
            let Some(event) = turn.next_event(None).await.expect("next_event") else {
                break;
            };
            if let ModelTurnEvent::Finished(result) = event {
                assert_eq!(result.finish_reason, FinishReason::Completed);
                saw_finished = true;
            }
        }
        assert!(saw_finished);
    }

    /// A cancellation that fired before the call must win over available data.
    #[tokio::test(flavor = "current_thread")]
    async fn streaming_turn_respects_pre_fired_cancellation() {
        let mut turn = streaming_turn_from(vec!["data: {\"id\":\"m\",\"choices\":[]}\n\n"]);

        let controller = CancellationController::new();
        let checkpoint = TurnCancellation::new(controller.handle());
        controller.interrupt();

        let outcome = turn.next_event(Some(checkpoint)).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, LoopError::Cancelled));
    }
}