use tokio_util::sync::CancellationToken;
use crate::client::LLMClient;
use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::controller::session::LLMProvider;
use super::types::{
DEFAULT_MAX_TOKENS, RequestOptions, StatelessConfig, StatelessError, StatelessResult,
StreamCallback,
};
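/// Executes single-turn LLM requests: one provider-specific client built from
/// a [`StatelessConfig`], with no conversation state kept between calls.
///
/// Illustrative usage (placeholder key and model name; assumes an async
/// context in which the futures can be awaited):
///
/// ```ignore
/// let config = StatelessConfig::anthropic("api-key", "claude-3");
/// let executor = StatelessExecutor::new(config)?;
/// let result = executor.execute("Summarize this paragraph.", None).await?;
/// println!("{}", result.text);
/// ```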
pub struct StatelessExecutor {
client: LLMClient,
config: StatelessConfig,
}
impl StatelessExecutor {
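    /// Validates `config` and constructs the matching provider client.
    /// Fails if validation rejects the config, a required provider field is
    /// missing, or the client cannot be built.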
pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
config.validate()?;
        // Every provider arm maps client-construction failures identically,
        // so share a single error mapper.
        fn init_error(e: impl std::fmt::Display) -> StatelessError {
            StatelessError::ExecutionFailed {
                op: "init_client".to_string(),
                message: format!("failed to initialize LLM client: {}", e),
            }
        }
        let client = match config.provider {
            LLMProvider::Anthropic => {
                let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(init_error)?
            }
LLMProvider::OpenAI => {
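                // Endpoint precedence: Azure (resource + deployment) first,
                // then a custom base URL, then the default OpenAI endpoint.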
let provider = if let (Some(resource), Some(deployment)) =
(&config.azure_resource, &config.azure_deployment)
{
let api_version = config
.azure_api_version
.clone()
.unwrap_or_else(|| "2024-10-21".to_string());
OpenAIProvider::azure(
config.api_key.clone(),
resource.clone(),
deployment.clone(),
api_version,
)
} else if let Some(base_url) = &config.base_url {
OpenAIProvider::with_base_url(
config.api_key.clone(),
config.model.clone(),
base_url.clone(),
)
} else {
OpenAIProvider::new(config.api_key.clone(), config.model.clone())
};
                LLMClient::new(Box::new(provider)).map_err(init_error)?
}
LLMProvider::Google => {
let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(init_error)?
}
LLMProvider::Cohere => {
let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(init_error)?
}
LLMProvider::Bedrock => {
let region = config.bedrock_region.clone().ok_or_else(|| {
StatelessError::ExecutionFailed {
op: "init_client".to_string(),
message: "Bedrock requires bedrock_region".to_string(),
}
})?;
let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
StatelessError::ExecutionFailed {
op: "init_client".to_string(),
message: "Bedrock requires bedrock_access_key_id".to_string(),
}
})?;
let secret_access_key =
config.bedrock_secret_access_key.clone().ok_or_else(|| {
StatelessError::ExecutionFailed {
op: "init_client".to_string(),
message: "Bedrock requires bedrock_secret_access_key".to_string(),
}
})?;
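                // A session token signals temporary AWS credentials; include
                // it when present.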
let credentials = match &config.bedrock_session_token {
Some(token) => BedrockCredentials::with_session_token(
access_key_id,
secret_access_key,
token.clone(),
),
None => BedrockCredentials::new(access_key_id, secret_access_key),
};
let provider = BedrockProvider::new(credentials, region, config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(init_error)?
}
};
Ok(Self { client, config })
}
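    /// Sends `input` as a single user message, preceded by any system prompt,
    /// and returns the complete response text.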
pub async fn execute(
&self,
input: &str,
options: Option<RequestOptions>,
) -> Result<StatelessResult, StatelessError> {
if input.is_empty() {
return Err(StatelessError::EmptyInput);
}
        let msg_opts = self.build_message_options(options.as_ref());
        let messages = self.build_messages(input, options.as_ref());
let response = self
.client
.send_message(&messages, &msg_opts)
.await
.map_err(|e| StatelessError::ExecutionFailed {
op: "send_message".to_string(),
message: e.to_string(),
})?;
let text = self.extract_text(&response);
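        // Token counts are not populated on the non-streaming path.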
Ok(StatelessResult {
text,
            input_tokens: 0,
            output_tokens: 0,
            model: self.config.model.clone(),
stop_reason: None,
})
}
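    /// Streams the response, invoking `callback` once per text delta.
    /// Returns `StatelessError::Cancelled` if `cancel_token` fires first, and
    /// `StatelessError::StreamInterrupted` if the callback reports an error.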
pub async fn execute_stream(
&self,
input: &str,
mut callback: StreamCallback,
options: Option<RequestOptions>,
cancel_token: Option<CancellationToken>,
) -> Result<StatelessResult, StatelessError> {
use futures::StreamExt;
if input.is_empty() {
return Err(StatelessError::EmptyInput);
}
        let msg_opts = self.build_message_options(options.as_ref());
        let messages = self.build_messages(input, options.as_ref());
let mut stream = self
.client
.send_message_stream(&messages, &msg_opts)
.await
.map_err(|e| StatelessError::ExecutionFailed {
op: "create_stream".to_string(),
message: e.to_string(),
})?;
let mut result = StatelessResult {
model: self.config.model.clone(),
..Default::default()
};
let mut text_builder = String::new();
let cancel = cancel_token.unwrap_or_default();
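        // Race stream events against cancellation so callers can abort mid-stream.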
loop {
tokio::select! {
_ = cancel.cancelled() => {
return Err(StatelessError::Cancelled);
}
event = stream.next() => {
match event {
Some(Ok(stream_event)) => {
match stream_event {
StreamEvent::MessageStart { model, .. } => {
result.model = model;
}
StreamEvent::TextDelta { text, .. } => {
text_builder.push_str(&text);
if callback(&text).is_err() {
return Err(StatelessError::StreamInterrupted);
}
}
StreamEvent::MessageDelta { stop_reason, usage } => {
if let Some(usage) = usage {
result.input_tokens = usage.input_tokens as i64;
result.output_tokens = usage.output_tokens as i64;
}
result.stop_reason = stop_reason;
}
StreamEvent::MessageStop => {
break;
}
_ => {}
}
}
Some(Err(e)) => {
return Err(StatelessError::ExecutionFailed {
op: "streaming".to_string(),
message: e.to_string(),
});
}
None => {
break;
}
}
}
}
}
result.text = text_builder;
Ok(result)
}
    /// Builds the outgoing message list: an optional system prompt (the
    /// per-request prompt overrides the configured default) followed by the
    /// user input.
    fn build_messages(&self, input: &str, options: Option<&RequestOptions>) -> Vec<LLMMessage> {
        let mut messages = Vec::new();
        let system_prompt = options
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());
        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }
        messages.push(LLMMessage::user(input));
        messages
    }

    /// Merges per-request options with configured defaults, falling back to
    /// `DEFAULT_MAX_TOKENS` when neither specifies a positive limit.
    fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
let max_tokens = opts
.and_then(|o| o.max_tokens)
.unwrap_or(if self.config.max_tokens > 0 {
self.config.max_tokens
} else {
DEFAULT_MAX_TOKENS
});
let temperature = opts.and_then(|o| o.temperature).or(self.config.temperature);
MessageOptions {
max_tokens: Some(max_tokens),
temperature,
..Default::default()
}
}
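    /// Concatenates the text content blocks of `message`, skipping any
    /// non-text blocks.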
fn extract_text(&self, message: &LLMMessage) -> String {
use crate::client::models::Content;
let mut text = String::new();
for block in &message.content {
if let Content::Text(t) = block {
text.push_str(t);
}
}
text
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_validation() {
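        // An empty API key must be rejected.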
let config = StatelessConfig {
provider: LLMProvider::Anthropic,
api_key: "".to_string(),
model: "claude-3".to_string(),
base_url: None,
max_tokens: 4096,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
};
assert!(config.validate().is_err());
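        // An empty model must be rejected.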
let config = StatelessConfig {
provider: LLMProvider::Anthropic,
api_key: "test-key".to_string(),
model: "".to_string(),
base_url: None,
max_tokens: 4096,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
};
assert!(config.validate().is_err());
let config = StatelessConfig::anthropic("test-key", "claude-3");
assert!(config.validate().is_ok());
}
#[test]
fn test_request_options_builder() {
let opts = RequestOptions::new()
.with_model("gpt-4")
.with_max_tokens(2048)
.with_system_prompt("Be helpful")
.with_temperature(0.7);
assert_eq!(opts.model, Some("gpt-4".to_string()));
assert_eq!(opts.max_tokens, Some(2048));
assert_eq!(opts.system_prompt, Some("Be helpful".to_string()));
assert_eq!(opts.temperature, Some(0.7));
}
#[test]
fn test_config_builder() {
let config = StatelessConfig::anthropic("key", "model")
.with_max_tokens(8192)
.with_system_prompt("You are helpful")
.with_temperature(0.5);
assert_eq!(config.api_key, "key");
assert_eq!(config.model, "model");
assert_eq!(config.max_tokens, 8192);
assert_eq!(config.system_prompt, Some("You are helpful".to_string()));
assert_eq!(config.temperature, Some(0.5));
}
}