pub mod anthropic;
pub mod bedrock;
pub mod validation;
pub mod gemini;
pub mod litellm;
pub mod llamaapi;
pub mod llamacpp;
pub mod mistral;
pub mod ollama;
pub mod openai;
pub mod sagemaker;
pub mod writer;
use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use crate::types::{
content::{Message, SystemContentBlock},
errors::StrandsError,
streaming::StreamEvent,
tools::{ToolChoice, ToolSpec},
};
/// Configuration that every [`Model`] exposes through [`Model::config`].
///
/// The derived `Default` produces an empty `model_id`, `None` for each optional
/// field and an empty `additional` map.
#[derive(Debug, Clone, Default)]
pub struct ModelConfig {
/// Identifier of the model; how it is interpreted is up to the implementation
/// consuming this configuration (not visible in this file).
pub model_id: String,
/// Optional token limit; `None` means "not set".
/// NOTE(review): presumably the maximum number of tokens to generate — confirm.
pub max_tokens: Option<u32>,
/// Optional temperature; `None` means "not set".
/// NOTE(review): valid range is not enforced here — confirm per provider.
pub temperature: Option<f32>,
/// Optional `top_p` value; `None` means "not set".
/// NOTE(review): valid range is not enforced here — confirm per provider.
pub top_p: Option<f32>,
/// Optional list of stop sequences; `None` means "not set".
pub stop_sequences: Option<Vec<String>>,
/// Free-form extra settings keyed by name, stored as arbitrary JSON values.
/// NOTE(review): presumably forwarded to providers as-is — verify against callers.
pub additional: std::collections::HashMap<String, serde_json::Value>,
}
impl ModelConfig {
    /// Creates a configuration for `model_id`, leaving every other field at its
    /// `Default` value (`None` / empty map).
    pub fn new(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            ..Default::default()
        }
    }

    /// Sets `max_tokens` to `Some(max_tokens)` and returns the updated configuration.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `temperature` to `Some(temperature)` and returns the updated configuration.
    ///
    /// The value is stored as given; no range check is performed here.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `top_p` to `Some(top_p)` and returns the updated configuration.
    ///
    /// The value is stored as given; no range check is performed here.
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets `stop_sequences` to the collected items and returns the updated configuration.
    ///
    /// Accepts anything iterable whose items convert into `String`
    /// (e.g. `["END", "STOP"]` or a `Vec<String>`). Any previously stored
    /// sequences are replaced. An empty iterator stores `Some(vec![])`, not `None`.
    pub fn with_stop_sequences<I, S>(mut self, stop_sequences: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.stop_sequences = Some(stop_sequences.into_iter().map(Into::into).collect());
        self
    }
}
/// Boxed, pinned, `Send` stream that yields `Result<StreamEvent, StrandsError>` items
/// and may borrow data for the lifetime `'a`.
///
/// This is the return type of [`Model::stream`].
pub type StreamEventStream<'a> =
    Pin<Box<dyn Stream<Item = Result<StreamEvent, StrandsError>> + Send + 'a>>;
/// Interface for a model that exposes a [`ModelConfig`] and produces a stream of
/// [`StreamEvent`]s for a conversation.
///
/// Implementors must be `Send + Sync`.
///
/// NOTE(review): the trait is annotated with `#[async_trait]` although it declares no
/// `async fn`; confirm whether implementors rely on the attribute before removing it.
#[async_trait]
pub trait Model: Send + Sync {
/// Returns a shared reference to the current configuration.
fn config(&self) -> &ModelConfig;
/// Applies `config` to this model.
/// NOTE(review): presumably replaces the stored configuration wholesale — whether
/// implementations merge or replace is not visible here.
fn update_config(&mut self, config: ModelConfig);
/// Starts a request and returns the resulting event stream.
///
/// The returned stream may borrow from `self` and from every borrowed argument
/// for the lifetime `'a`.
///
/// * `messages` - the messages to send.
/// * `tool_specs` - optional tool specifications.
/// * `system_prompt` - optional system prompt as plain text.
/// * `tool_choice` - optional tool-choice setting, passed by value.
/// * `system_prompt_content` - optional system prompt given as content blocks.
///
/// How `system_prompt` and `system_prompt_content` interact when both are supplied
/// is decided by each implementation and is not visible in this file.
///
/// Failures are reported as `Err(StrandsError)` items of the stream.
fn stream<'a>(
&'a self,
messages: &'a [Message],
tool_specs: Option<&'a [ToolSpec]>,
system_prompt: Option<&'a str>,
tool_choice: Option<ToolChoice>,
system_prompt_content: Option<&'a [SystemContentBlock]>,
) -> StreamEventStream<'a>;
}
/// Returns the payload of `raw` when the *entire* text is wrapped in a Markdown code fence.
///
/// Models frequently answer with JSON wrapped as "```json … ```". When `raw` (ignoring
/// surrounding whitespace) both starts and ends with a triple backtick, the fence and an
/// optional alphanumeric info string on the opening line (e.g. `json`) are removed and the
/// trimmed inner text is returned. In every other case `raw` is returned unchanged, so
/// input that is not fenced is parsed exactly as before.
fn strip_markdown_code_fence(raw: &str) -> &str {
    let fenced = raw
        .trim()
        .strip_prefix("```")
        .and_then(|rest| rest.strip_suffix("```"));
    let inner = match fenced {
        Some(inner) => inner,
        None => return raw,
    };
    // Only treat the first line as an info string when it looks like one; otherwise the
    // payload starts right after the opening fence and must be kept.
    let body = match inner.split_once('\n') {
        Some((info, rest)) if info.trim().chars().all(|c| c.is_ascii_alphanumeric()) => rest,
        _ => inner,
    };
    body.trim()
}

/// Convenience methods layered on top of [`Model`].
#[async_trait]
pub trait ModelExt: Model {
    /// Streams a response for `messages` and deserializes the collected text into `T`.
    ///
    /// The model is invoked through [`Model::stream`] without tool specs, tool choice or
    /// system prompt content blocks. Every text delta of the stream is concatenated; other
    /// events are ignored. The collected text is parsed as JSON; a Markdown code fence
    /// wrapping the whole reply is stripped first.
    ///
    /// NOTE(review): `T: schemars::JsonSchema` is required by the signature, but this
    /// implementation does not send the schema to the model; callers must describe the
    /// expected shape in `messages` or `system_prompt` themselves.
    ///
    /// # Errors
    ///
    /// * Returns the first `Err` yielded by the stream.
    /// * Returns `StrandsError::StructuredOutputError` when the collected text is not
    ///   valid JSON for `T`.
    async fn structured_output<T>(
        &self,
        messages: &[Message],
        system_prompt: Option<&str>,
    ) -> Result<T, StrandsError>
    where
        T: serde::de::DeserializeOwned + schemars::JsonSchema + Send,
    {
        use futures::StreamExt;

        let mut content = String::new();
        let mut stream = self.stream(messages, None, system_prompt, None, None);
        while let Some(event) = stream.next().await {
            let event = event?;
            if let Some(text) = event.as_text_delta() {
                content.push_str(text);
            }
        }

        serde_json::from_str(strip_markdown_code_fence(&content)).map_err(|e| {
            StrandsError::StructuredOutputError {
                message: format!("Failed to parse structured output: {e}"),
            }
        })
    }
}
/// Blanket implementation: every [`Model`] automatically gets the [`ModelExt`] methods.
///
/// The `?Sized` bound extends this to unsized implementors such as `dyn Model`, so the
/// extension methods can be called through trait objects (`&dyn Model`, `Box<dyn Model>`,
/// `Arc<dyn Model>`) as well as on concrete types.
impl<T: Model + ?Sized> ModelExt for T {}
pub use anthropic::AnthropicModel;
pub use bedrock::BedrockModel;
pub use gemini::{GeminiConfig, GeminiModel};
pub use litellm::{LiteLLMConfig, LiteLLMModel};
pub use llamaapi::{LlamaAPIConfig, LlamaAPIModel};
pub use llamacpp::{LlamaCppConfig, LlamaCppModel};
pub use mistral::{MistralConfig, MistralModel};
pub use ollama::OllamaModel;
pub use openai::OpenAIModel;
pub use sagemaker::{SageMakerEndpointConfig, SageMakerModel, SageMakerPayloadConfig};
pub use writer::{WriterConfig, WriterModel};
pub use validation::{
config_keys, validate_config_keys, warn_on_tool_choice_not_supported,
};