Skip to main content

agentkit_adapter_completions/
lib.rs

1//! Generic OpenAI-compatible chat completions adapter for agentkit.
2//!
3//! This crate provides the [`CompletionsProvider`] trait and a generic
4//! [`CompletionsAdapter`] that handles all common chat completions logic:
5//! transcript conversion, request building, response parsing, tool call
6//! extraction, usage mapping, cancellation, and multimodal content.
7//!
8//! Provider crates (OpenRouter, OpenAI, Ollama, etc.) implement
9//! [`CompletionsProvider`] to supply authentication, endpoint URLs, and
10//! provider-specific hooks. The adapter does the rest.
11//!
12//! # Example
13//!
14//! ```rust,ignore
15//! use agentkit_adapter_completions::{CompletionsAdapter, CompletionsProvider};
16//!
17//! let adapter = CompletionsAdapter::new(my_provider)?;
18//! let agent = Agent::builder().model(adapter).build()?;
19//! ```
20
21mod error;
22mod media;
23mod request;
24mod response;
25mod sse;
26mod stream;
27
28use std::collections::VecDeque;
29use std::sync::Arc;
30
31use agentkit_core::{MetadataMap, TurnCancellation, Usage};
32use agentkit_http::{BodyStream, Http, HttpError, HttpRequestBuilder, StatusCode};
33use agentkit_loop::{
34    LoopError, ModelAdapter, ModelSession, ModelTurn, ModelTurnEvent, SessionConfig, TurnRequest,
35};
36use async_trait::async_trait;
37use futures_util::StreamExt;
38use futures_util::future::{Either, select};
39use serde::Serialize;
40use serde_json::Value;
41
42pub use crate::error::CompletionsError;
43use crate::stream::{EventTranslator, PostprocessResponse, SseDecoder};
44
45/// Trait implemented by each provider to customise the generic chat completions adapter.
46///
47/// The associated [`Config`](CompletionsProvider::Config) type allows each provider
48/// to define a strongly-typed struct with the exact request parameters it supports.
49/// The adapter serialises it and merges it into the request body.
50///
51/// # Required methods
52///
53/// - [`provider_name`](CompletionsProvider::provider_name) — for error messages
54/// - [`endpoint_url`](CompletionsProvider::endpoint_url) — the chat completions URL
55/// - [`config`](CompletionsProvider::config) — returns the request configuration
56///
57/// # Hooks
58///
59/// All have default implementations that pass through unchanged:
60///
61/// - [`preprocess_request`](CompletionsProvider::preprocess_request) — add auth headers, custom user-agent, etc.
62/// - [`apply_prompt_cache`](CompletionsProvider::apply_prompt_cache) — map normalized cache requests into provider request fields
63/// - [`preprocess_response`](CompletionsProvider::preprocess_response) — inspect/reject raw response before parsing
64/// - [`postprocess_response`](CompletionsProvider::postprocess_response) — enrich parsed usage/metadata from raw response
65pub trait CompletionsProvider: Send + Sync + Clone {
66    /// Strongly-typed request configuration (model, temperature, top_p, etc.).
67    ///
68    /// Serialised via `serde_json::to_value` and merged into the request body.
69    /// Use `#[serde(skip_serializing_if = "Option::is_none")]` on optional fields
70    /// to avoid sending `null` values.
71    type Config: Serialize + Clone + Send + Sync;
72
73    /// Provider name for error messages (e.g. "OpenRouter", "Ollama").
74    fn provider_name(&self) -> &str;
75
76    /// The chat completions endpoint URL.
77    fn endpoint_url(&self) -> &str;
78
79    /// Returns the request configuration to merge into the body.
80    fn config(&self) -> &Self::Config;
81
82    /// Hook to modify the HTTP request before it is sent.
83    ///
84    /// Use this to add authentication headers, set a custom user-agent,
85    /// or apply any other request-level customisation.
86    ///
87    /// The default implementation passes the builder through unchanged.
88    fn preprocess_request(&self, builder: HttpRequestBuilder) -> HttpRequestBuilder {
89        builder
90    }
91
92    /// Hook to map a normalized prompt cache request into the provider's JSON
93    /// request body.
94    ///
95    /// Called after the adapter has constructed the standard chat-completions
96    /// payload. Providers can inspect [`TurnRequest::cache`] and mutate the
97    /// request body accordingly.
98    fn apply_prompt_cache(
99        &self,
100        _body: &mut serde_json::Map<String, Value>,
101        _request: &TurnRequest,
102    ) -> Result<(), LoopError> {
103        Ok(())
104    }
105
106    /// Whether to request an SSE streaming response. Defaults to `true`.
107    fn streaming(&self) -> bool {
108        true
109    }
110
111    /// Hook to add provider-specific streaming options to the JSON request.
112    ///
113    /// Providers that support terminal usage frames can insert fields such as
114    /// `stream_options`; the default leaves the request unchanged.
115    fn apply_stream_options(
116        &self,
117        _body: &mut serde_json::Map<String, Value>,
118    ) -> Result<(), LoopError> {
119        Ok(())
120    }
121
122    /// Whether the upstream chat template enforces strict
123    /// `user`/`assistant` role alternation.
124    ///
125    /// When `true`, the adapter merges adjacent `user`-role messages
126    /// (including notifications and tool-result follow-ups that come back
127    /// as user messages) into a single message before sending. Required
128    /// for vLLM-served Mistral templates and the Mistral hosted API; see
129    /// <https://github.com/vllm-project/vllm/issues/6862>.
130    ///
131    /// Defaults to `false`. Providers that target strictly-alternating
132    /// upstreams should override.
133    fn requires_alternating_roles(&self) -> bool {
134        false
135    }
136
137    /// Hook to inspect the raw HTTP response before deserialisation.
138    ///
139    /// Called after the response body is read but before it is parsed into
140    /// the chat completion response struct. Return `Err` to reject the
141    /// response (e.g. for providers that return HTTP 200 with an error payload).
142    ///
143    /// The default implementation does nothing.
144    fn preprocess_response(&self, _status: StatusCode, _body: &str) -> Result<(), LoopError> {
145        Ok(())
146    }
147
148    /// Hook to enrich parsed response data with provider-specific fields.
149    ///
150    /// Called after the standard response parsing is complete. The provider
151    /// can read extra fields from the raw JSON (e.g. `cost` in the usage
152    /// object, `model` or `refusal` in the response) and fold them into
153    /// the `Usage` and `MetadataMap` that will be attached to the output items.
154    ///
155    /// The default implementation does nothing.
156    fn postprocess_response(
157        &self,
158        _usage: &mut Option<Usage>,
159        _metadata: &mut MetadataMap,
160        _raw_response: &Value,
161    ) {
162    }
163}
164
165/// Generic chat completions adapter, parameterised by a [`CompletionsProvider`].
166///
167/// Implements [`ModelAdapter`] so it can be passed to
168/// [`Agent::builder().model()`](agentkit_loop::Agent::builder).
169#[derive(Clone)]
170pub struct CompletionsAdapter<P: CompletionsProvider> {
171    client: Http,
172    provider: Arc<P>,
173    /// Lowercase provider name stamped onto telemetry spans as the
174    /// `gen_ai.provider.name` attribute from the OTel GenAI semantic
175    /// conventions.
176    provider_label: String,
177}
178
179impl<P: CompletionsProvider> CompletionsAdapter<P> {
180    /// Creates a new adapter from the given provider.
181    ///
182    /// Builds a default reqwest-backed HTTP client reused for all sessions and turns.
183    pub fn new(provider: P) -> Result<Self, CompletionsError> {
184        let client = reqwest::Client::builder()
185            .build()
186            .map(Http::new)
187            .map_err(|error| CompletionsError::HttpClient(HttpError::request(error)))?;
188
189        Ok(Self {
190            client,
191            provider_label: provider.provider_name().to_lowercase(),
192            provider: Arc::new(provider),
193        })
194    }
195
196    /// Creates a new adapter with a pre-configured [`Http`] client. Use this to
197    /// attach auth headers via `default_headers`, supply custom TLS/proxies,
198    /// or plug in a non-reqwest backend.
199    pub fn with_client(provider: P, client: Http) -> Self {
200        Self {
201            client,
202            provider_label: provider.provider_name().to_lowercase(),
203            provider: Arc::new(provider),
204        }
205    }
206}
207
208/// An active session with a chat completions provider.
209///
210/// Created by [`CompletionsAdapter::start_session`](ModelAdapter::start_session).
211pub struct CompletionsSession<P: CompletionsProvider> {
212    client: Http,
213    provider: Arc<P>,
214    model: Option<String>,
215    _session_config: SessionConfig,
216}
217
218/// A turn from a chat completion response.
219pub struct CompletionsTurn {
220    inner: TurnInner,
221}
222
223enum TurnInner {
224    Buffered { events: VecDeque<ModelTurnEvent> },
225    Streaming(Box<StreamingState>),
226}
227
228struct StreamingState {
229    body: BodyStream,
230    decoder: SseDecoder,
231    translator: EventTranslator,
232    pending: VecDeque<ModelTurnEvent>,
233    eof: bool,
234    postprocess: PostprocessResponse,
235}
236
237impl CompletionsTurn {
238    fn buffered(events: VecDeque<ModelTurnEvent>) -> Self {
239        Self {
240            inner: TurnInner::Buffered { events },
241        }
242    }
243
244    fn streaming(body: BodyStream, postprocess: PostprocessResponse) -> Self {
245        Self {
246            inner: TurnInner::Streaming(Box::new(StreamingState {
247                body,
248                decoder: SseDecoder::new(),
249                translator: EventTranslator::new(),
250                pending: VecDeque::new(),
251                eof: false,
252                postprocess,
253            })),
254        }
255    }
256}
257
258#[async_trait]
259impl<P: CompletionsProvider + 'static> ModelAdapter for CompletionsAdapter<P> {
260    type Session = CompletionsSession<P>;
261
262    async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
263        // The provider's typed request config is opaque to the adapter; the
264        // serialized "model" key is the chat-completions contract, so pull
265        // the telemetry model name from there.
266        let model = serde_json::to_value(self.provider.config())
267            .ok()
268            .and_then(|config| {
269                config
270                    .get("model")
271                    .and_then(Value::as_str)
272                    .map(str::to_owned)
273            });
274        Ok(CompletionsSession {
275            client: self.client.clone(),
276            provider: self.provider.clone(),
277            model,
278            _session_config: config,
279        })
280    }
281
282    fn provider_name(&self) -> Option<&str> {
283        Some(&self.provider_label)
284    }
285}
286
287#[async_trait]
288impl<P: CompletionsProvider + 'static> ModelSession for CompletionsSession<P> {
289    type Turn = CompletionsTurn;
290
291    async fn begin_turn(
292        &mut self,
293        turn_request: TurnRequest,
294        cancellation: Option<TurnCancellation>,
295    ) -> Result<CompletionsTurn, LoopError> {
296        let provider = self.provider.clone();
297        let provider_name = provider.provider_name().to_owned();
298
299        let request_future = async {
300            let body = request::build_request_body(provider.as_ref(), &turn_request)
301                .map_err(|e| LoopError::Provider(e.to_string()))?;
302
303            let http = self
304                .client
305                .post(provider.endpoint_url())
306                .header("Content-Type", "application/json");
307
308            let mut http = provider.preprocess_request(http);
309            if provider.streaming() {
310                http = http.header("Accept", "text/event-stream");
311            }
312
313            let response = http.json(&body).send().await.map_err(|error| {
314                LoopError::Provider(format!("{provider_name} request failed: {error}"))
315            })?;
316
317            let status = response.status();
318            if provider.streaming() && status.is_success() {
319                let provider_for_postprocess = provider.clone();
320                let postprocess: PostprocessResponse = Arc::new(move |usage, metadata, raw| {
321                    provider_for_postprocess.postprocess_response(usage, metadata, raw);
322                });
323                return Ok(CompletionsTurn::streaming(
324                    response.bytes_stream(),
325                    postprocess,
326                ));
327            }
328
329            let body = response.text().await.map_err(|error| {
330                LoopError::Provider(format!(
331                    "failed to read {provider_name} response body: {error}"
332                ))
333            })?;
334
335            provider.preprocess_response(status, &body)?;
336
337            if !status.is_success() {
338                return Err(LoopError::Provider(format!(
339                    "{provider_name} request failed with status {status}: {body}"
340                )));
341            }
342
343            let (events, _raw) = response::build_turn_from_response(provider.as_ref(), &body)
344                .map_err(|e| LoopError::Provider(e.to_string()))?;
345
346            Ok(CompletionsTurn::buffered(events))
347        };
348
349        if let Some(cancellation) = cancellation {
350            futures_util::pin_mut!(request_future);
351            let cancelled = cancellation.cancelled();
352            futures_util::pin_mut!(cancelled);
353            match select(request_future, cancelled).await {
354                Either::Left((result, _)) => result,
355                Either::Right((_, _)) => Err(LoopError::Cancelled),
356            }
357        } else {
358            request_future.await
359        }
360    }
361
362    fn model_name(&self) -> Option<&str> {
363        self.model.as_deref()
364    }
365}
366
367#[async_trait]
368impl ModelTurn for CompletionsTurn {
369    async fn next_event(
370        &mut self,
371        cancellation: Option<TurnCancellation>,
372    ) -> Result<Option<ModelTurnEvent>, LoopError> {
373        if cancellation
374            .as_ref()
375            .is_some_and(TurnCancellation::is_cancelled)
376        {
377            return Err(LoopError::Cancelled);
378        }
379        match &mut self.inner {
380            TurnInner::Buffered { events } => Ok(events.pop_front()),
381            TurnInner::Streaming(state) => {
382                let StreamingState {
383                    body,
384                    decoder,
385                    translator,
386                    pending,
387                    eof,
388                    postprocess,
389                } = state.as_mut();
390                next_streaming_event(
391                    body,
392                    decoder,
393                    translator,
394                    pending,
395                    eof,
396                    postprocess,
397                    cancellation,
398                )
399                .await
400            }
401        }
402    }
403}
404
405async fn next_streaming_event(
406    body: &mut BodyStream,
407    decoder: &mut SseDecoder,
408    translator: &mut EventTranslator,
409    pending: &mut VecDeque<ModelTurnEvent>,
410    eof: &mut bool,
411    postprocess: &PostprocessResponse,
412    cancellation: Option<TurnCancellation>,
413) -> Result<Option<ModelTurnEvent>, LoopError> {
414    loop {
415        if let Some(event) = pending.pop_front() {
416            return Ok(Some(event));
417        }
418        if *eof || translator.is_done() {
419            return Ok(None);
420        }
421
422        let chunk = if let Some(cancellation) = cancellation.as_ref() {
423            let next = body.next();
424            futures_util::pin_mut!(next);
425            let cancelled = cancellation.cancelled();
426            futures_util::pin_mut!(cancelled);
427            match select(next, cancelled).await {
428                Either::Left((chunk, _)) => chunk,
429                Either::Right((_, _)) => return Err(LoopError::Cancelled),
430            }
431        } else {
432            body.next().await
433        };
434
435        match chunk {
436            Some(Ok(bytes)) => {
437                let text = std::str::from_utf8(&bytes).map_err(|e| {
438                    LoopError::Provider(format!("invalid UTF-8 in completions stream: {e}"))
439                })?;
440                for sse in decoder.feed(text) {
441                    for event in translator
442                        .handle(&sse, postprocess)
443                        .map_err(|e| LoopError::Provider(e.to_string()))?
444                    {
445                        pending.push_back(event);
446                    }
447                }
448            }
449            Some(Err(e)) => {
450                return Err(LoopError::Provider(format!(
451                    "completions stream body error: {e}"
452                )));
453            }
454            None => {
455                *eof = true;
456            }
457        }
458    }
459}