#[cfg(feature = "openai-compat")]
use std::pin::Pin;
use futures::stream::{self, Stream, StreamExt as _};
#[cfg(feature = "openai-compat")]
use tokio_util::sync::CancellationToken;
#[cfg(feature = "openai-compat")]
use tracing::debug;
use tracing::warn;
#[cfg(feature = "mistral")]
use serde::Serialize;
use swink_agent::{AgentContext, AssistantMessageEvent, ModelSpec, StreamOptions};
use crate::base::AdapterBase;
use crate::convert;
use crate::openai_compat::{
OaiChatRequest, OaiConverter, OaiStreamOptions, build_oai_tools, parse_oai_sse_stream,
};
pub struct OaiAdapterShell {
provider: &'static str,
base: AdapterBase,
chat_completions_path: &'static str,
}
impl OaiAdapterShell {
#[cfg(any(feature = "openai-compat", feature = "mistral"))]
pub(crate) fn new(
provider: &'static str,
base_url: impl Into<String>,
api_key: impl Into<String>,
) -> Self {
Self {
provider,
base: AdapterBase::new(base_url, api_key),
chat_completions_path: "/v1/chat/completions",
}
}
#[cfg(any(test, feature = "azure"))]
pub(crate) fn new_with_path(
provider: &'static str,
base_url: impl Into<String>,
api_key: impl Into<String>,
chat_completions_path: &'static str,
) -> Self {
Self {
provider,
base: AdapterBase::new(base_url, api_key),
chat_completions_path,
}
}
#[cfg(any(test, feature = "azure"))]
pub(crate) fn base_url(&self) -> &str {
&self.base.base_url
}
#[cfg(feature = "azure")]
pub(crate) const fn client(&self) -> &reqwest::Client {
&self.base.client
}
#[cfg(any(feature = "openai-compat", feature = "mistral"))]
pub(crate) fn fmt_debug(
&self,
name: &'static str,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
f.debug_struct(name)
.field("base_url", &self.base.base_url)
.field("api_key", &"[REDACTED]")
.finish_non_exhaustive()
}
#[cfg(any(feature = "azure", feature = "mistral"))]
pub(crate) const fn provider(&self) -> &'static str {
self.provider
}
pub(crate) fn chat_completions_url(&self) -> String {
format!("{}{}", self.base.base_url, self.chat_completions_path)
}
#[cfg(any(feature = "openai-compat", feature = "mistral"))]
pub(crate) fn api_key<'a>(&'a self, options: &'a StreamOptions) -> &'a str {
options.api_key.as_deref().unwrap_or(&self.base.api_key)
}
#[cfg(feature = "mistral")]
pub(crate) fn post_json_request<T: Serialize>(
&self,
url: &str,
body: &T,
options: &StreamOptions,
) -> reqwest::RequestBuilder {
self.base
.client
.post(url)
.header("Authorization", format!("Bearer {}", self.api_key(options)))
.json(body)
}
#[cfg(feature = "openai-compat")]
pub(crate) fn stream<'a>(
&'a self,
model: &'a ModelSpec,
context: &'a AgentContext,
options: &'a StreamOptions,
cancellation_token: CancellationToken,
) -> Pin<Box<dyn Stream<Item = AssistantMessageEvent> + Send + 'a>> {
let url = self.chat_completions_url();
debug!(
provider = self.provider,
%url,
model = %model.model_id,
messages = context.messages.len(),
"sending OAI-compatible request"
);
let request = prepare_oai_request(&self.base.client, &url, model, context, options)
.header("Authorization", format!("Bearer {}", self.api_key(options)));
Box::pin(oai_send_and_parse(
request,
self.provider,
cancellation_token,
options.on_raw_payload.clone(),
|_, _| None,
))
}
}
pub fn prepare_oai_request(
client: &reqwest::Client,
url: &str,
model: &ModelSpec,
context: &AgentContext,
options: &StreamOptions,
) -> reqwest::RequestBuilder {
let messages =
convert::convert_messages::<OaiConverter>(&context.messages, &context.system_prompt);
let (tools, tool_choice) = build_oai_tools(&context.tools);
let body = OaiChatRequest {
model: model.model_id.clone(),
messages,
stream: true,
stream_options: Some(OaiStreamOptions {
include_usage: true,
}),
temperature: options.temperature,
max_tokens: options.max_tokens,
tools,
tool_choice,
};
client.post(url).json(&body)
}
pub fn oai_send_and_parse<'a>(
request: reqwest::RequestBuilder,
provider: &'static str,
cancellation_token: tokio_util::sync::CancellationToken,
on_raw_payload: Option<swink_agent::OnRawPayload>,
classify_error: impl Fn(u16, &str) -> Option<AssistantMessageEvent> + Send + 'a,
) -> impl Stream<Item = AssistantMessageEvent> + Send + 'a {
stream::once(async move {
let response = match tokio::select! {
() = cancellation_token.cancelled() => {
return stream::iter(Vec::from(crate::base::pre_stream_error(
crate::base::cancelled_error("operation cancelled"),
)))
.left_stream();
}
response = request.send() => response
} {
Ok(resp) => resp,
Err(e) => {
return stream::iter(vec![
AssistantMessageEvent::Start,
AssistantMessageEvent::error_network(format!(
"{provider} connection error: {e}"
)),
])
.left_stream();
}
};
let status = response.status();
if !status.is_success() {
let code = status.as_u16();
let body = match crate::base::read_error_body_or_cancelled(
response,
&cancellation_token,
"operation cancelled",
)
.await
{
Ok(body) => body,
Err(event) => {
return stream::iter(Vec::from(crate::base::pre_stream_error(event)))
.left_stream();
}
};
warn!(status = code, "{provider} HTTP error");
if let Some(event) = classify_error(code, &body) {
return stream::iter(vec![AssistantMessageEvent::Start, event]).left_stream();
}
let event = crate::classify::error_event_from_status(code, &body, provider);
return stream::iter(vec![AssistantMessageEvent::Start, event]).left_stream();
}
parse_oai_sse_stream(response, cancellation_token, provider, on_raw_payload).right_stream()
})
.flatten()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn custom_chat_path_is_used() {
let shell = OaiAdapterShell::new_with_path(
"Azure",
"https://example.openai.azure.com/openai/deployments/gpt-4/",
"",
"/chat/completions",
);
assert_eq!(
shell.chat_completions_url(),
"https://example.openai.azure.com/openai/deployments/gpt-4/chat/completions"
);
}
}