use std::fmt;
use crate::controller::session::LLMProvider;
/// Token budget (`max_tokens`) assigned by every `StatelessConfig` constructor
/// until overridden via `with_max_tokens`.
pub const DEFAULT_MAX_TOKENS: u32 = 4_096;
#[derive(Debug, Clone)]
pub struct StatelessConfig {
pub provider: LLMProvider,
pub api_key: String,
pub model: String,
pub base_url: Option<String>,
pub max_tokens: u32,
pub system_prompt: Option<String>,
pub temperature: Option<f32>,
pub azure_resource: Option<String>,
pub azure_deployment: Option<String>,
pub azure_api_version: Option<String>,
pub bedrock_region: Option<String>,
pub bedrock_access_key_id: Option<String>,
pub bedrock_secret_access_key: Option<String>,
pub bedrock_session_token: Option<String>,
}
impl StatelessConfig {
pub fn anthropic(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LLMProvider::Anthropic,
api_key: api_key.into(),
model: model.into(),
base_url: None,
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
}
}
pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LLMProvider::OpenAI,
api_key: api_key.into(),
model: model.into(),
base_url: None,
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
}
}
pub fn openai_compatible(
api_key: impl Into<String>,
model: impl Into<String>,
base_url: impl Into<String>,
) -> Self {
Self {
provider: LLMProvider::OpenAI,
api_key: api_key.into(),
model: model.into(),
base_url: Some(base_url.into()),
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
}
}
pub fn google(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LLMProvider::Google,
api_key: api_key.into(),
model: model.into(),
base_url: None,
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
}
}
pub fn azure_openai(
api_key: impl Into<String>,
resource: impl Into<String>,
deployment: impl Into<String>,
) -> Self {
Self {
provider: LLMProvider::OpenAI,
api_key: api_key.into(),
model: String::new(),
base_url: None,
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: Some(resource.into()),
azure_deployment: Some(deployment.into()),
azure_api_version: Some("2024-10-21".to_string()),
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
}
}
pub fn with_azure_api_version(mut self, version: impl Into<String>) -> Self {
self.azure_api_version = Some(version.into());
self
}
pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: LLMProvider::Cohere,
api_key: api_key.into(),
model: model.into(),
base_url: None,
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: None,
bedrock_access_key_id: None,
bedrock_secret_access_key: None,
bedrock_session_token: None,
}
}
pub fn bedrock(
access_key_id: impl Into<String>,
secret_access_key: impl Into<String>,
region: impl Into<String>,
model: impl Into<String>,
) -> Self {
Self {
provider: LLMProvider::Bedrock,
api_key: String::new(), model: model.into(),
base_url: None,
max_tokens: DEFAULT_MAX_TOKENS,
system_prompt: None,
temperature: None,
azure_resource: None,
azure_deployment: None,
azure_api_version: None,
bedrock_region: Some(region.into()),
bedrock_access_key_id: Some(access_key_id.into()),
bedrock_secret_access_key: Some(secret_access_key.into()),
bedrock_session_token: None,
}
}
pub fn with_bedrock_session_token(mut self, token: impl Into<String>) -> Self {
self.bedrock_session_token = Some(token.into());
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn validate(&self) -> Result<(), StatelessError> {
if self.api_key.is_empty() {
return Err(StatelessError::MissingApiKey);
}
if self.model.is_empty() {
return Err(StatelessError::MissingModel);
}
Ok(())
}
}
/// Output of a completed stateless request.
///
/// `Default` yields empty strings, zero token counts and no stop reason.
#[derive(Clone, Debug, Default)]
pub struct StatelessResult {
    /// Generated text.
    pub text: String,
    /// Prompt-side token count (presumably as reported by the provider — confirm).
    pub input_tokens: i64,
    /// Completion-side token count (presumably as reported by the provider — confirm).
    pub output_tokens: i64,
    /// Model identifier associated with the response.
    pub model: String,
    /// Reason generation stopped, when one is available.
    pub stop_reason: Option<String>,
}
/// Failures that a stateless request can report.
#[derive(Clone, Debug, PartialEq)]
pub enum StatelessError {
    /// No API key / credentials were configured.
    MissingApiKey,
    /// No model was configured.
    MissingModel,
    /// The request input was empty.
    EmptyInput,
    /// The request was cancelled before completing.
    Cancelled,
    /// A streaming callback asked to stop the stream.
    StreamInterrupted,
    /// The operation `op` failed with the given message.
    ExecutionFailed {
        /// Name of the failing operation.
        op: String,
        /// Human-readable failure detail.
        message: String,
    },
}
impl fmt::Display for StatelessError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StatelessError::MissingApiKey => write!(f, "stateless: API key is required"),
StatelessError::MissingModel => write!(f, "stateless: model is required"),
StatelessError::EmptyInput => write!(f, "stateless: input cannot be empty"),
StatelessError::Cancelled => write!(f, "stateless: request cancelled"),
StatelessError::StreamInterrupted => {
write!(f, "stateless: stream interrupted by callback")
}
StatelessError::ExecutionFailed { op, message } => {
write!(f, "stateless: {}: {}", op, message)
}
}
}
}
impl std::error::Error for StatelessError {}
/// Optional settings for a single request; every field defaults to `None`.
///
/// Built with the chained `with_*` methods.
#[derive(Debug, Clone, Default)]
pub struct RequestOptions {
    /// Model identifier for this request.
    pub model: Option<String>,
    /// Generated-token limit for this request.
    pub max_tokens: Option<u32>,
    /// System prompt for this request.
    pub system_prompt: Option<String>,
    /// Sampling temperature for this request.
    pub temperature: Option<f32>,
}

impl RequestOptions {
    /// Returns options with nothing set (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the model, replacing any previous value.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the generated-token limit, replacing any previous value.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Sets the system prompt, replacing any previous value.
    pub fn with_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Sets the sampling temperature, replacing any previous value.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }
}
/// Boxed, sendable callback that receives streamed text chunks.
///
/// Returning `Err(())` presumably asks the streamer to stop, surfacing as
/// `StatelessError::StreamInterrupted` — confirm against the caller.
pub type StreamCallback = Box<dyn Send + FnMut(&str) -> Result<(), ()>>;