use crate::config::{Capabilities, RetryConfig};
use crate::error::Result;
use crate::request::CompletionRequest;
use crate::response::{CompletionChunk, CompletionResponse};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::borrow::Cow;
use std::time::Duration;
/// Ordered list of header name/value pairs.
///
/// Both halves of each pair are `Cow<'static, str>`, so an entry can hold either a
/// borrowed `&'static str` or an owned `String`.
pub type Headers = Vec<(Cow<'static, str>, Cow<'static, str>)>;
/// Header-name constants, usable wherever a `&'static str` header name is expected.
pub mod headers {
    /// Name of the `Authorization` header.
    pub const AUTHORIZATION: &str = "Authorization";

    /// Name of the `Content-Type` header.
    pub const CONTENT_TYPE: &str = "Content-Type";

    /// Name of the `x-api-key` header.
    pub const X_API_KEY: &str = "x-api-key";
}
/// A completion backend.
///
/// An implementor converts a [`CompletionRequest`] into a [`ProviderRequest`], executes
/// it to obtain a [`ProviderResponse`], and converts that back into a
/// [`CompletionResponse`]. Implementors must be `Send + Sync`.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Name of this provider.
    fn name(&self) -> &str;

    /// Builds the [`ProviderRequest`] that corresponds to `req`.
    fn transform_request(&self, req: &CompletionRequest) -> Result<ProviderRequest>;

    /// Executes `req` and returns the resulting [`ProviderResponse`].
    async fn execute(&self, req: ProviderRequest) -> Result<ProviderResponse>;

    /// Converts a [`ProviderResponse`] into a [`CompletionResponse`].
    fn transform_response(&self, resp: ProviderResponse) -> Result<CompletionResponse>;

    /// Retry settings for this provider; defaults to `RetryConfig::default()`.
    fn retry_config(&self) -> RetryConfig {
        RetryConfig::default()
    }

    /// Capabilities of this provider; defaults to `Capabilities::default()`.
    fn capabilities(&self) -> Capabilities {
        Capabilities::default()
    }

    /// Timeout for this provider; defaults to 30 seconds.
    fn timeout(&self) -> Duration {
        Duration::from_secs(30)
    }

    /// Default streaming implementation built on top of [`Provider::execute`].
    ///
    /// If the request body is a JSON object that already has a `stream` entry, that entry
    /// is overwritten with `false`. The request is then run through `execute` and
    /// `transform_response`, and the complete response is repackaged as a stream that
    /// yields exactly one [`CompletionChunk`] before ending.
    ///
    /// In that chunk every choice carries the role, content and finish reason of the
    /// full response; `reasoning_content` and `tool_calls` are always `None`, and `usage`
    /// is always `Some`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `execute` or `transform_response`.
    async fn execute_stream(
        &self,
        mut req: ProviderRequest,
    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
        use std::pin::Pin;
        use std::task::{Context, Poll};

        // Stream adapter that hands out one stored item and is exhausted afterwards.
        struct OneShot {
            pending: Option<Result<CompletionChunk>>,
        }

        impl futures_core::Stream for OneShot {
            type Item = Result<CompletionChunk>;

            fn poll_next(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Self::Item>> {
                // `take` leaves `None` behind, so every later poll reports end-of-stream.
                Poll::Ready(self.pending.take())
            }
        }

        // Only an existing `stream` key of an object body is touched; nothing is inserted.
        if let Some(flag) = req.body.get_mut("stream") {
            *flag = Value::Bool(false);
        }

        let raw = self.execute(req).await?;
        let completed = self.transform_response(raw)?;

        let chunk = CompletionChunk {
            choices: completed
                .choices
                .into_iter()
                .map(|full| crate::response::ChoiceDelta {
                    index: full.index,
                    delta: crate::response::MessageDelta {
                        role: Some(full.message.role),
                        content: Some(full.message.content),
                        reasoning_content: None,
                        tool_calls: None,
                    },
                    finish_reason: Some(full.finish_reason),
                })
                .collect(),
            id: completed.id,
            model: completed.model,
            created: completed.created,
            usage: Some(completed.usage),
        };

        Ok(Box::new(OneShot {
            pending: Some(Ok(chunk)),
        }))
    }
}
/// Description of a single outgoing provider call: where to send it, with which
/// headers and JSON body, and an optional timeout.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderRequest {
    /// Target URL.
    pub url: String,

    /// Header name/value pairs, (de)serialized through the `header_serde` module.
    #[serde(with = "header_serde")]
    pub headers: Headers,

    /// JSON body.
    pub body: serde_json::Value,

    /// Optional timeout; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,
}
/// Serde adapter for [`Headers`], referenced via `#[serde(with = "header_serde")]`.
///
/// Headers are written as a sequence of `(name, value)` string pairs and read back as
/// owned strings.
mod header_serde {
    use super::Headers;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::borrow::Cow;

    /// Serializes `headers` as a sequence of `(name, value)` string pairs.
    ///
    /// `Cow<str>` serializes as the string it holds, so the slice is handed to the
    /// serializer as-is; no temporary `Vec<(&str, &str)>` is allocated and the output is
    /// the same for borrowed and owned entries.
    ///
    /// # Errors
    ///
    /// Returns whatever error the serializer reports.
    pub fn serialize<S>(
        headers: &[(Cow<'static, str>, Cow<'static, str>)],
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        headers.serialize(serializer)
    }

    /// Deserializes a sequence of `(name, value)` string pairs into [`Headers`].
    ///
    /// Every name and value comes back as `Cow::Owned`, because the data is read into
    /// `String`s first.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error if the input is not a sequence of string pairs.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Headers, D::Error>
    where
        D: Deserializer<'de>,
    {
        let owned_pairs = Vec::<(String, String)>::deserialize(deserializer)?;
        Ok(owned_pairs
            .into_iter()
            .map(|(name, value)| (Cow::Owned(name), Cow::Owned(value)))
            .collect())
    }
}
impl ProviderRequest {
    /// Creates a request for `url` with no headers, a JSON `null` body and no timeout.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            headers: Vec::new(),
            body: serde_json::Value::Null,
            timeout: None,
        }
    }

    /// Appends a header whose name and value are stored as owned strings.
    ///
    /// Headers already present are kept, including ones with the same name.
    pub fn with_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let name: String = name.into();
        let value: String = value.into();
        self.headers.push((Cow::Owned(name), Cow::Owned(value)));
        self
    }

    /// Appends a header whose name and value are borrowed `'static` strings.
    ///
    /// Headers already present are kept, including ones with the same name.
    pub fn with_static_header(mut self, name: &'static str, value: &'static str) -> Self {
        let entry = (Cow::Borrowed(name), Cow::Borrowed(value));
        self.headers.push(entry);
        self
    }

    /// Replaces the body with `body`.
    pub fn with_body(self, body: serde_json::Value) -> Self {
        Self { body, ..self }
    }

    /// Sets the timeout to `Some(timeout)`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
/// Result of a provider call: a numeric status, a JSON body and, optionally, headers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderResponse {
    /// Numeric status code; classified by `is_success`, `is_client_error` and
    /// `is_server_error`.
    pub status: u16,

    /// JSON body.
    pub body: serde_json::Value,

    /// Optional header name/value pairs; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<Vec<(String, String)>>,
}
impl ProviderResponse {
    /// Creates a response with the given status and body and no headers.
    pub fn new(status: u16, body: serde_json::Value) -> Self {
        Self {
            headers: None,
            status,
            body,
        }
    }

    /// Returns `true` when the status is in `200..=299`.
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }

    /// Returns `true` when the status is in `400..=499`.
    pub fn is_client_error(&self) -> bool {
        matches!(self.status, 400..=499)
    }

    /// Returns `true` when the status is in `500..=599`.
    pub fn is_server_error(&self) -> bool {
        matches!(self.status, 500..=599)
    }

    /// Sets the headers to `Some(headers)`, replacing any previous value.
    pub fn with_headers(self, headers: Vec<(String, String)>) -> Self {
        Self {
            headers: Some(headers),
            ..self
        }
    }
}
// Unit tests for the request/response value types and for `Provider` being usable as a
// trait object.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder call must end up in the matching field.
    #[test]
    fn test_provider_request_builder() {
        let url = "https://api.example.com/v1/completions";
        let timeout = Duration::from_secs(30);

        let req = ProviderRequest::new(url)
            .with_header("Authorization", "Bearer sk-test")
            .with_header("Content-Type", "application/json")
            .with_body(serde_json::json!({"model": "test"}))
            .with_timeout(timeout);

        assert_eq!(req.url, url);
        assert_eq!(req.headers.len(), 2);
        assert_eq!(req.body["model"], "test");
        assert_eq!(req.timeout, Some(timeout));
    }

    /// One representative status per class: exactly one predicate is true for each.
    #[test]
    fn test_provider_response_status_checks() {
        // (status, is_success, is_client_error, is_server_error)
        let expectations = [
            (200u16, true, false, false),
            (404, false, true, false),
            (500, false, false, true),
        ];

        for &(status, success, client_error, server_error) in expectations.iter() {
            let resp = ProviderResponse::new(status, serde_json::json!({}));
            assert_eq!(resp.is_success(), success);
            assert_eq!(resp.is_client_error(), client_error);
            assert_eq!(resp.is_server_error(), server_error);
        }
    }

    /// A request must survive a JSON round trip unchanged.
    #[test]
    fn test_provider_request_serialization() {
        let original = ProviderRequest::new("https://api.example.com")
            .with_header("X-Test", "value")
            .with_body(serde_json::json!({"key": "value"}));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// A response, including its headers, must survive a JSON round trip unchanged.
    #[test]
    fn test_provider_response_serialization() {
        let header_pairs = vec![("X-Request-ID".to_string(), "123".to_string())];
        let original = ProviderResponse::new(200, serde_json::json!({"result": "success"}))
            .with_headers(header_pairs);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProviderResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }

    /// Compile-time check only: this fails to build if `Provider` cannot be used as
    /// `dyn Provider`.
    #[test]
    fn test_provider_object_safety() {
        fn _assert_object_safe(_: &dyn Provider) {}
    }
}