agentkit_adapter_completions/
lib.rs1mod error;
22mod media;
23mod request;
24mod response;
25mod sse;
26mod stream;
27
28use std::collections::VecDeque;
29use std::sync::Arc;
30
31use agentkit_core::{MetadataMap, TurnCancellation, Usage};
32use agentkit_http::{BodyStream, Http, HttpError, HttpRequestBuilder, StatusCode};
33use agentkit_loop::{
34 LoopError, ModelAdapter, ModelSession, ModelTurn, ModelTurnEvent, SessionConfig, TurnRequest,
35};
36use async_trait::async_trait;
37use futures_util::StreamExt;
38use futures_util::future::{Either, select};
39use serde::Serialize;
40use serde_json::Value;
41
42pub use crate::error::CompletionsError;
43use crate::stream::{EventTranslator, PostprocessResponse, SseDecoder};
44
45pub trait CompletionsProvider: Send + Sync + Clone {
66 type Config: Serialize + Clone + Send + Sync;
72
73 fn provider_name(&self) -> &str;
75
76 fn endpoint_url(&self) -> &str;
78
79 fn config(&self) -> &Self::Config;
81
82 fn preprocess_request(&self, builder: HttpRequestBuilder) -> HttpRequestBuilder {
89 builder
90 }
91
92 fn apply_prompt_cache(
99 &self,
100 _body: &mut serde_json::Map<String, Value>,
101 _request: &TurnRequest,
102 ) -> Result<(), LoopError> {
103 Ok(())
104 }
105
106 fn streaming(&self) -> bool {
108 true
109 }
110
111 fn apply_stream_options(
116 &self,
117 _body: &mut serde_json::Map<String, Value>,
118 ) -> Result<(), LoopError> {
119 Ok(())
120 }
121
122 fn requires_alternating_roles(&self) -> bool {
134 false
135 }
136
137 fn preprocess_response(&self, _status: StatusCode, _body: &str) -> Result<(), LoopError> {
145 Ok(())
146 }
147
148 fn postprocess_response(
157 &self,
158 _usage: &mut Option<Usage>,
159 _metadata: &mut MetadataMap,
160 _raw_response: &Value,
161 ) {
162 }
163}
164
165#[derive(Clone)]
170pub struct CompletionsAdapter<P: CompletionsProvider> {
171 client: Http,
172 provider: Arc<P>,
173 provider_label: String,
177}
178
179impl<P: CompletionsProvider> CompletionsAdapter<P> {
180 pub fn new(provider: P) -> Result<Self, CompletionsError> {
184 let client = reqwest::Client::builder()
185 .build()
186 .map(Http::new)
187 .map_err(|error| CompletionsError::HttpClient(HttpError::request(error)))?;
188
189 Ok(Self {
190 client,
191 provider_label: provider.provider_name().to_lowercase(),
192 provider: Arc::new(provider),
193 })
194 }
195
196 pub fn with_client(provider: P, client: Http) -> Self {
200 Self {
201 client,
202 provider_label: provider.provider_name().to_lowercase(),
203 provider: Arc::new(provider),
204 }
205 }
206}
207
208pub struct CompletionsSession<P: CompletionsProvider> {
212 client: Http,
213 provider: Arc<P>,
214 model: Option<String>,
215 _session_config: SessionConfig,
216}
217
218pub struct CompletionsTurn {
220 inner: TurnInner,
221}
222
223enum TurnInner {
224 Buffered { events: VecDeque<ModelTurnEvent> },
225 Streaming(Box<StreamingState>),
226}
227
228struct StreamingState {
229 body: BodyStream,
230 decoder: SseDecoder,
231 translator: EventTranslator,
232 pending: VecDeque<ModelTurnEvent>,
233 eof: bool,
234 postprocess: PostprocessResponse,
235}
236
237impl CompletionsTurn {
238 fn buffered(events: VecDeque<ModelTurnEvent>) -> Self {
239 Self {
240 inner: TurnInner::Buffered { events },
241 }
242 }
243
244 fn streaming(body: BodyStream, postprocess: PostprocessResponse) -> Self {
245 Self {
246 inner: TurnInner::Streaming(Box::new(StreamingState {
247 body,
248 decoder: SseDecoder::new(),
249 translator: EventTranslator::new(),
250 pending: VecDeque::new(),
251 eof: false,
252 postprocess,
253 })),
254 }
255 }
256}
257
258#[async_trait]
259impl<P: CompletionsProvider + 'static> ModelAdapter for CompletionsAdapter<P> {
260 type Session = CompletionsSession<P>;
261
262 async fn start_session(&self, config: SessionConfig) -> Result<Self::Session, LoopError> {
263 let model = serde_json::to_value(self.provider.config())
267 .ok()
268 .and_then(|config| {
269 config
270 .get("model")
271 .and_then(Value::as_str)
272 .map(str::to_owned)
273 });
274 Ok(CompletionsSession {
275 client: self.client.clone(),
276 provider: self.provider.clone(),
277 model,
278 _session_config: config,
279 })
280 }
281
282 fn provider_name(&self) -> Option<&str> {
283 Some(&self.provider_label)
284 }
285}
286
287#[async_trait]
288impl<P: CompletionsProvider + 'static> ModelSession for CompletionsSession<P> {
289 type Turn = CompletionsTurn;
290
291 async fn begin_turn(
292 &mut self,
293 turn_request: TurnRequest,
294 cancellation: Option<TurnCancellation>,
295 ) -> Result<CompletionsTurn, LoopError> {
296 let provider = self.provider.clone();
297 let provider_name = provider.provider_name().to_owned();
298
299 let request_future = async {
300 let body = request::build_request_body(provider.as_ref(), &turn_request)
301 .map_err(|e| LoopError::Provider(e.to_string()))?;
302
303 let http = self
304 .client
305 .post(provider.endpoint_url())
306 .header("Content-Type", "application/json");
307
308 let mut http = provider.preprocess_request(http);
309 if provider.streaming() {
310 http = http.header("Accept", "text/event-stream");
311 }
312
313 let response = http.json(&body).send().await.map_err(|error| {
314 LoopError::Provider(format!("{provider_name} request failed: {error}"))
315 })?;
316
317 let status = response.status();
318 if provider.streaming() && status.is_success() {
319 let provider_for_postprocess = provider.clone();
320 let postprocess: PostprocessResponse = Arc::new(move |usage, metadata, raw| {
321 provider_for_postprocess.postprocess_response(usage, metadata, raw);
322 });
323 return Ok(CompletionsTurn::streaming(
324 response.bytes_stream(),
325 postprocess,
326 ));
327 }
328
329 let body = response.text().await.map_err(|error| {
330 LoopError::Provider(format!(
331 "failed to read {provider_name} response body: {error}"
332 ))
333 })?;
334
335 provider.preprocess_response(status, &body)?;
336
337 if !status.is_success() {
338 return Err(LoopError::Provider(format!(
339 "{provider_name} request failed with status {status}: {body}"
340 )));
341 }
342
343 let (events, _raw) = response::build_turn_from_response(provider.as_ref(), &body)
344 .map_err(|e| LoopError::Provider(e.to_string()))?;
345
346 Ok(CompletionsTurn::buffered(events))
347 };
348
349 if let Some(cancellation) = cancellation {
350 futures_util::pin_mut!(request_future);
351 let cancelled = cancellation.cancelled();
352 futures_util::pin_mut!(cancelled);
353 match select(request_future, cancelled).await {
354 Either::Left((result, _)) => result,
355 Either::Right((_, _)) => Err(LoopError::Cancelled),
356 }
357 } else {
358 request_future.await
359 }
360 }
361
362 fn model_name(&self) -> Option<&str> {
363 self.model.as_deref()
364 }
365}
366
367#[async_trait]
368impl ModelTurn for CompletionsTurn {
369 async fn next_event(
370 &mut self,
371 cancellation: Option<TurnCancellation>,
372 ) -> Result<Option<ModelTurnEvent>, LoopError> {
373 if cancellation
374 .as_ref()
375 .is_some_and(TurnCancellation::is_cancelled)
376 {
377 return Err(LoopError::Cancelled);
378 }
379 match &mut self.inner {
380 TurnInner::Buffered { events } => Ok(events.pop_front()),
381 TurnInner::Streaming(state) => {
382 let StreamingState {
383 body,
384 decoder,
385 translator,
386 pending,
387 eof,
388 postprocess,
389 } = state.as_mut();
390 next_streaming_event(
391 body,
392 decoder,
393 translator,
394 pending,
395 eof,
396 postprocess,
397 cancellation,
398 )
399 .await
400 }
401 }
402 }
403}
404
405async fn next_streaming_event(
406 body: &mut BodyStream,
407 decoder: &mut SseDecoder,
408 translator: &mut EventTranslator,
409 pending: &mut VecDeque<ModelTurnEvent>,
410 eof: &mut bool,
411 postprocess: &PostprocessResponse,
412 cancellation: Option<TurnCancellation>,
413) -> Result<Option<ModelTurnEvent>, LoopError> {
414 loop {
415 if let Some(event) = pending.pop_front() {
416 return Ok(Some(event));
417 }
418 if *eof || translator.is_done() {
419 return Ok(None);
420 }
421
422 let chunk = if let Some(cancellation) = cancellation.as_ref() {
423 let next = body.next();
424 futures_util::pin_mut!(next);
425 let cancelled = cancellation.cancelled();
426 futures_util::pin_mut!(cancelled);
427 match select(next, cancelled).await {
428 Either::Left((chunk, _)) => chunk,
429 Either::Right((_, _)) => return Err(LoopError::Cancelled),
430 }
431 } else {
432 body.next().await
433 };
434
435 match chunk {
436 Some(Ok(bytes)) => {
437 let text = std::str::from_utf8(&bytes).map_err(|e| {
438 LoopError::Provider(format!("invalid UTF-8 in completions stream: {e}"))
439 })?;
440 for sse in decoder.feed(text) {
441 for event in translator
442 .handle(&sse, postprocess)
443 .map_err(|e| LoopError::Provider(e.to_string()))?
444 {
445 pending.push_back(event);
446 }
447 }
448 }
449 Some(Err(e)) => {
450 return Err(LoopError::Provider(format!(
451 "completions stream body error: {e}"
452 )));
453 }
454 None => {
455 *eof = true;
456 }
457 }
458 }
459}