use crate::error::AgentRuntimeError;
use async_trait::async_trait;
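/// Per-request options for a completion call, built fluently with the
/// `with_*` methods. Only `model` is required; unset options fall back to
/// provider defaults.
///
/// Illustrative usage (the model name here is a placeholder):
///
/// ```ignore
/// let opts = CompletionOptions::new("some-model")
///     .with_max_tokens(512)
///     .with_temperature(0.2)
///     .with_timeout_secs(30);
/// ```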
#[derive(Debug)]
pub struct CompletionOptions<'a> {
pub model: &'a str,
pub max_tokens: Option<usize>,
pub temperature: Option<f32>,
pub timeout: Option<std::time::Duration>,
pub stop_sequences: Vec<String>,
}
impl<'a> CompletionOptions<'a> {
pub fn new(model: &'a str) -> Self {
Self {
model,
max_tokens: None,
temperature: None,
timeout: None,
stop_sequences: vec![],
}
}
pub fn with_max_tokens(mut self, n: usize) -> Self {
self.max_tokens = Some(n);
self
}
pub fn with_temperature(mut self, t: f32) -> Self {
self.temperature = Some(t);
self
}
pub fn with_timeout(mut self, d: std::time::Duration) -> Self {
self.timeout = Some(d);
self
}
pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
self.stop_sequences = sequences;
self
}
pub fn with_timeout_secs(self, secs: u64) -> Self {
self.with_timeout(std::time::Duration::from_secs(secs))
}
pub fn with_timeout_ms(self, ms: u64) -> Self {
self.with_timeout(std::time::Duration::from_millis(ms))
}
pub fn has_stop_sequences(&self) -> bool {
!self.stop_sequences.is_empty()
}
pub fn stop_sequence_count(&self) -> usize {
self.stop_sequences.len()
}
}
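/// Abstraction over an LLM backend.
///
/// Implementors only have to supply `complete`; `complete_with_options` and
/// `stream_complete` have default implementations that fall back to a single
/// blocking completion, which keeps simple providers and test stubs small.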
#[async_trait]
pub trait LlmProvider: Send + Sync {
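/// Sends `prompt` to the given `model` and returns the full completion text.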
async fn complete(&self, prompt: &str, model: &str) -> Result<String, AgentRuntimeError>;
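/// Completes with per-request options. The default implementation ignores
/// everything except `options.model` and delegates to `complete`.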
async fn complete_with_options(
&self,
prompt: &str,
options: CompletionOptions<'_>,
) -> Result<String, AgentRuntimeError> {
self.complete(prompt, options.model).await
}
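/// Streams the completion as chunks over an mpsc channel. The default
/// implementation performs one blocking completion and emits it as a single chunk.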
async fn stream_complete(
&self,
prompt: &str,
model: &str,
) -> Result<tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>, AgentRuntimeError>
{
let result = self.complete(prompt, model).await;
let (tx, rx) = tokio::sync::mpsc::channel(64);
let _ = tx.send(result).await;
Ok(rx)
}
}
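/// Provider backed by the Anthropic Messages API. Concurrent streaming calls
/// are bounded by an internal semaphore.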
#[cfg(feature = "anthropic")]
pub struct AnthropicProvider {
api_key: String,
api_url: String,
client: reqwest::Client,
stream_semaphore: std::sync::Arc<tokio::sync::Semaphore>,
stream_max_tokens: Option<u32>,
}
#[cfg(feature = "anthropic")]
impl AnthropicProvider {
const DEFAULT_API_URL: &'static str = "https://api.anthropic.com/v1/messages";
const API_VERSION: &'static str = "2023-06-01";
const MAX_TOKENS: u32 = 1024;
const DEFAULT_STREAM_CONCURRENCY: usize = 32;
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
api_url: Self::DEFAULT_API_URL.to_owned(),
client: reqwest::Client::new(),
stream_semaphore: std::sync::Arc::new(tokio::sync::Semaphore::new(
Self::DEFAULT_STREAM_CONCURRENCY,
)),
stream_max_tokens: None,
}
}
pub fn with_base_url(api_key: impl Into<String>, api_url: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
api_url: api_url.into(),
client: reqwest::Client::new(),
stream_semaphore: std::sync::Arc::new(tokio::sync::Semaphore::new(
Self::DEFAULT_STREAM_CONCURRENCY,
)),
stream_max_tokens: None,
}
}
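/// Like `new`, but caps the number of concurrently consumed streams at `max`.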
pub fn with_max_concurrent_streams(api_key: impl Into<String>, max: usize) -> Self {
Self {
api_key: api_key.into(),
api_url: Self::DEFAULT_API_URL.to_owned(),
client: reqwest::Client::new(),
stream_semaphore: std::sync::Arc::new(tokio::sync::Semaphore::new(max)),
stream_max_tokens: None,
}
}
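/// Sets the `max_tokens` value used for streaming requests; non-streaming
/// requests take their limit from `CompletionOptions` instead.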
pub fn with_stream_max_tokens(mut self, max_tokens: u32) -> Self {
self.stream_max_tokens = Some(max_tokens);
self
}
}
#[cfg(feature = "anthropic")]
#[async_trait]
impl LlmProvider for AnthropicProvider {
async fn complete(&self, prompt: &str, model: &str) -> Result<String, AgentRuntimeError> {
self.complete_with_options(prompt, CompletionOptions::new(model))
.await
}
#[tracing::instrument(skip(self, prompt, options), fields(model = options.model, provider = "anthropic"))]
async fn complete_with_options(
&self,
prompt: &str,
options: CompletionOptions<'_>,
) -> Result<String, AgentRuntimeError> {
let max_tokens = options
.max_tokens
.unwrap_or(Self::MAX_TOKENS as usize) as u32;
let mut body = serde_json::json!({
"model": options.model,
"max_tokens": max_tokens,
"messages": [{ "role": "user", "content": prompt }]
});
if let Some(t) = options.temperature {
body["temperature"] = serde_json::json!(t);
}
if options.has_stop_sequences() {
// Forward custom stop strings; the Anthropic Messages API calls this field `stop_sequences`.
body["stop_sequences"] = serde_json::json!(options.stop_sequences);
}
let mut req = self
.client
.post(&self.api_url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", Self::API_VERSION)
.header("content-type", "application/json")
.json(&body);
if let Some(timeout) = options.timeout {
req = req.timeout(timeout);
}
let response = req
.send()
.await
.map_err(|e| AgentRuntimeError::Provider(format!("Anthropic request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(AgentRuntimeError::Provider(format!(
"Anthropic API error {status}: {text}"
)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| AgentRuntimeError::Provider(format!("Anthropic parse failed: {e}")))?;
let text = json["content"]
.as_array()
.and_then(|arr| arr.first())
.and_then(|block| block["text"].as_str())
.ok_or_else(|| {
AgentRuntimeError::Provider("Anthropic response missing content[0].text".into())
})?;
Ok(text.to_owned())
}
async fn stream_complete(
&self,
prompt: &str,
model: &str,
) -> Result<tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>, AgentRuntimeError>
{
let max_tokens = self.stream_max_tokens.unwrap_or(Self::MAX_TOKENS);
let body = serde_json::json!({
"model": model,
"max_tokens": max_tokens,
"stream": true,
"messages": [{ "role": "user", "content": prompt }]
});
let response = self
.client
.post(&self.api_url)
.header("x-api-key", &self.api_key)
.header("anthropic-version", Self::API_VERSION)
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| {
AgentRuntimeError::Provider(format!("Anthropic stream request failed: {e}"))
})?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(AgentRuntimeError::Provider(format!(
"Anthropic stream API error {status}: {text}"
)));
}
let (tx, rx) = tokio::sync::mpsc::channel::<Result<String, AgentRuntimeError>>(32);
let permit = std::sync::Arc::clone(&self.stream_semaphore)
.acquire_owned()
.await
.map_err(|_| {
AgentRuntimeError::Provider("Anthropic stream semaphore closed".into())
})?;
tokio::spawn(async move {
let _permit = permit;
let mut response = response;
let mut buffer = String::new();
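// Read the SSE body chunk by chunk, accumulating partial lines in `buffer`.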
loop {
match response.chunk().await {
Ok(Some(chunk)) => {
match String::from_utf8(chunk.to_vec()) {
Ok(s) => buffer.push_str(&s),
Err(e) => {
let _ = tx
.send(Err(AgentRuntimeError::Provider(format!(
"Anthropic stream: invalid UTF-8 in chunk: {e}"
))))
.await;
return;
}
}
// Split the buffered text into complete SSE lines and handle each `data:` payload.
while let Some(newline) = buffer.find('\n') {
let line: String = buffer.drain(..=newline).collect();
let line = line.trim();
if let Some(data) = line.strip_prefix("data: ") {
if let Ok(json) =
serde_json::from_str::<serde_json::Value>(data)
{
// Anthropic ends the stream with a `message_stop` event;
// there is no OpenAI-style `[DONE]` sentinel.
if json["type"] == "message_stop" {
return;
}
// `content_block_delta` events carry the next text fragment.
if let Some(delta) = json["delta"]["text"].as_str() {
if tx.send(Ok(delta.to_owned())).await.is_err() {
return;
}
}
}
}
}
}
Ok(None) => break,
Err(e) => {
let _ = tx
.send(Err(AgentRuntimeError::Provider(format!(
"Anthropic stream chunk error: {e}"
))))
.await;
return;
}
}
}
});
Ok(rx)
}
}
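/// Provider backed by the OpenAI Chat Completions API, or any compatible
/// endpoint configured through `with_base_url`. Concurrent streaming calls
/// are bounded by an internal semaphore.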
#[cfg(feature = "openai")]
pub struct OpenAiProvider {
api_key: String,
base_url: String,
client: reqwest::Client,
stream_semaphore: std::sync::Arc<tokio::sync::Semaphore>,
}
#[cfg(feature = "openai")]
impl OpenAiProvider {
const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1";
const DEFAULT_STREAM_CONCURRENCY: usize = 32;
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
base_url: Self::DEFAULT_BASE_URL.to_owned(),
client: reqwest::Client::new(),
stream_semaphore: std::sync::Arc::new(tokio::sync::Semaphore::new(
Self::DEFAULT_STREAM_CONCURRENCY,
)),
}
}
pub fn with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
base_url: base_url.into(),
client: reqwest::Client::new(),
stream_semaphore: std::sync::Arc::new(tokio::sync::Semaphore::new(
Self::DEFAULT_STREAM_CONCURRENCY,
)),
}
}
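/// Like `with_base_url`, but also caps the number of concurrently consumed streams at `max`.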
pub fn with_max_concurrent_streams(
api_key: impl Into<String>,
base_url: impl Into<String>,
max: usize,
) -> Self {
Self {
api_key: api_key.into(),
base_url: base_url.into(),
client: reqwest::Client::new(),
stream_semaphore: std::sync::Arc::new(tokio::sync::Semaphore::new(max)),
}
}
}
#[cfg(feature = "openai")]
#[async_trait]
impl LlmProvider for OpenAiProvider {
#[tracing::instrument(skip_all, fields(model, provider = "openai"))]
async fn complete(&self, prompt: &str, model: &str) -> Result<String, AgentRuntimeError> {
self.complete_with_options(prompt, CompletionOptions::new(model))
.await
}
#[tracing::instrument(skip(self, prompt, options), fields(model = options.model, provider = "openai"))]
async fn complete_with_options(
&self,
prompt: &str,
options: CompletionOptions<'_>,
) -> Result<String, AgentRuntimeError> {
let url = format!("{}/chat/completions", self.base_url);
let mut body = serde_json::json!({
"model": options.model,
"messages": [{ "role": "user", "content": prompt }]
});
if let Some(max_tokens) = options.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
if let Some(temp) = options.temperature {
body["temperature"] = serde_json::json!(temp);
}
if options.has_stop_sequences() {
// Forward custom stop strings; the OpenAI Chat Completions API calls this field `stop`.
body["stop"] = serde_json::json!(options.stop_sequences);
}
let mut req = self
.client
.post(&url)
.bearer_auth(&self.api_key)
.header("content-type", "application/json")
.json(&body);
if let Some(timeout) = options.timeout {
req = req.timeout(timeout);
}
let response = req
.send()
.await
.map_err(|e| AgentRuntimeError::Provider(format!("OpenAI request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(AgentRuntimeError::Provider(format!(
"OpenAI API error {status}: {text}"
)));
}
let json: serde_json::Value = response
.json()
.await
.map_err(|e| AgentRuntimeError::Provider(format!("OpenAI parse failed: {e}")))?;
let text = json["choices"]
.as_array()
.and_then(|arr| arr.first())
.and_then(|choice| choice["message"]["content"].as_str())
.ok_or_else(|| {
AgentRuntimeError::Provider(
"OpenAI response missing choices[0].message.content".into(),
)
})?;
Ok(text.to_owned())
}
async fn stream_complete(
&self,
prompt: &str,
model: &str,
) -> Result<tokio::sync::mpsc::Receiver<Result<String, AgentRuntimeError>>, AgentRuntimeError>
{
let url = format!("{}/chat/completions", self.base_url);
let body = serde_json::json!({
"model": model,
"stream": true,
"messages": [{ "role": "user", "content": prompt }]
});
let mut response = self
.client
.post(&url)
.bearer_auth(&self.api_key)
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| {
AgentRuntimeError::Provider(format!("OpenAI stream request failed: {e}"))
})?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
return Err(AgentRuntimeError::Provider(format!(
"OpenAI stream API error {status}: {text}"
)));
}
let (tx, rx) = tokio::sync::mpsc::channel::<Result<String, AgentRuntimeError>>(32);
let permit = std::sync::Arc::clone(&self.stream_semaphore)
.acquire_owned()
.await
.map_err(|_| {
AgentRuntimeError::Provider("OpenAI stream semaphore closed".into())
})?;
tokio::spawn(async move {
let _permit = permit;
let mut buffer = String::new();
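// Read the SSE body chunk by chunk, accumulating partial lines in `buffer`.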
loop {
match response.chunk().await {
Ok(Some(chunk)) => {
let text = match std::str::from_utf8(&chunk) {
Ok(t) => t,
Err(e) => {
let _ = tx
.send(Err(AgentRuntimeError::Provider(format!(
"OpenAI stream chunk is not valid UTF-8: {e}"
))))
.await;
return;
}
};
buffer.push_str(text);
while let Some(newline) = buffer.find('\n') {
let line: String = buffer.drain(..=newline).collect();
let line = line.trim_end_matches(['\r', '\n']);
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return;
}
if let Ok(json) =
serde_json::from_str::<serde_json::Value>(data)
{
if let Some(content) = json["choices"]
.as_array()
.and_then(|c| c.first())
.and_then(|c| c["delta"]["content"].as_str())
{
if tx.send(Ok(content.to_owned())).await.is_err() {
return;
}
}
}
}
}
}
Ok(None) => break,
Err(e) => {
let _ = tx
.send(Err(AgentRuntimeError::Provider(format!(
"OpenAI stream read failed: {e}"
))))
.await;
return;
}
}
}
});
Ok(rx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
struct StubProvider {
response: String,
}
#[async_trait]
impl LlmProvider for StubProvider {
async fn complete(&self, _prompt: &str, _model: &str) -> Result<String, AgentRuntimeError> {
Ok(self.response.clone())
}
}
#[tokio::test]
async fn test_stub_provider_returns_configured_response() {
let p = StubProvider {
response: "hello".into(),
};
let result = p.complete("prompt", "stub-model").await.unwrap();
assert_eq!(result, "hello");
}
#[tokio::test]
async fn test_llm_provider_is_object_safe() {
let p: Arc<dyn LlmProvider> = Arc::new(StubProvider {
response: "ok".into(),
});
let result = p.complete("test", "model").await.unwrap();
assert_eq!(result, "ok");
}
#[tokio::test]
async fn test_stub_provider_ignores_model_parameter() {
let p = StubProvider {
response: "42".into(),
};
let r1 = p.complete("q", "model-a").await.unwrap();
let r2 = p.complete("q", "model-b").await.unwrap();
assert_eq!(r1, r2);
}
#[tokio::test]
async fn test_stub_provider_stream_returns_single_chunk() {
let p = StubProvider {
response: "hello world".into(),
};
let mut rx = p.stream_complete("prompt", "model").await.unwrap();
let mut collected = String::new();
while let Some(chunk) = rx.recv().await {
collected.push_str(&chunk.unwrap());
}
assert_eq!(collected, "hello world");
}
#[tokio::test]
async fn test_stream_receiver_closes_after_completion() {
let p = StubProvider {
response: "done".into(),
};
let mut rx = p.stream_complete("prompt", "model").await.unwrap();
while let Some(_chunk) = rx.recv().await {}
assert!(rx.recv().await.is_none());
}
}