// strands_agents/models/llamaapi.rs

1//! Llama API model provider.
2//!
3//! This provider integrates with Meta's Llama API.
4//! See: https://llama.developer.meta.com/
5
6use crate::types::content::{Message, SystemContentBlock};
7use crate::types::errors::StrandsError;
8use crate::types::tools::{ToolChoice, ToolSpec};
9
10use super::{Model, ModelConfig, StreamEventStream};
11
/// Configuration for Llama API models.
///
/// Construct with [`LlamaAPIConfig::new`] and customize via the `with_*`
/// builder methods. All tuning parameters default to `None`, in which case
/// the Llama API applies its own server-side defaults.
///
/// `PartialEq` is derived so whole configurations can be compared directly
/// (e.g. in tests); `Eq` is not derivable because of the `f64` fields.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct LlamaAPIConfig {
    /// Model ID (e.g., "Llama-4-Maverick-17B-128E-Instruct-FP8").
    pub model_id: String,
    /// Repetition penalty.
    pub repetition_penalty: Option<f64>,
    /// Temperature for sampling.
    pub temperature: Option<f64>,
    /// Top-p for nucleus sampling.
    pub top_p: Option<f64>,
    /// Maximum completion tokens.
    pub max_completion_tokens: Option<u32>,
    /// Top-k for sampling.
    pub top_k: Option<u32>,
    /// API key for authentication.
    pub api_key: Option<String>,
}
30
31impl LlamaAPIConfig {
32    /// Create a new Llama API config.
33    pub fn new(model_id: impl Into<String>) -> Self {
34        Self {
35            model_id: model_id.into(),
36            ..Default::default()
37        }
38    }
39
40    /// Set temperature.
41    pub fn with_temperature(mut self, temperature: f64) -> Self {
42        self.temperature = Some(temperature);
43        self
44    }
45
46    /// Set top-p.
47    pub fn with_top_p(mut self, top_p: f64) -> Self {
48        self.top_p = Some(top_p);
49        self
50    }
51
52    /// Set max completion tokens.
53    pub fn with_max_completion_tokens(mut self, max_tokens: u32) -> Self {
54        self.max_completion_tokens = Some(max_tokens);
55        self
56    }
57
58    /// Set repetition penalty.
59    pub fn with_repetition_penalty(mut self, penalty: f64) -> Self {
60        self.repetition_penalty = Some(penalty);
61        self
62    }
63
64    /// Set top-k.
65    pub fn with_top_k(mut self, top_k: u32) -> Self {
66        self.top_k = Some(top_k);
67        self
68    }
69
70    /// Set API key.
71    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
72        self.api_key = Some(api_key.into());
73        self
74    }
75}
76
/// Llama API model provider implementation.
///
/// Holds both the generic [`ModelConfig`] shared across providers (kept in
/// sync with the Llama model ID) and the provider-specific
/// [`LlamaAPIConfig`].
pub struct LlamaAPIModel {
    /// Generic model configuration derived from the Llama model ID.
    config: ModelConfig,
    /// Provider-specific settings (sampling parameters, API key, ...).
    llamaapi_config: LlamaAPIConfig,
}
82
83impl LlamaAPIModel {
84    /// Create a new Llama API model.
85    pub fn new(config: LlamaAPIConfig) -> Self {
86        Self {
87            config: ModelConfig::new(&config.model_id),
88            llamaapi_config: config,
89        }
90    }
91
92    /// Get the Llama API configuration.
93    pub fn llamaapi_config(&self) -> &LlamaAPIConfig {
94        &self.llamaapi_config
95    }
96
97    /// Update the Llama API configuration.
98    pub fn update_llamaapi_config(&mut self, config: LlamaAPIConfig) {
99        self.config = ModelConfig::new(&config.model_id);
100        self.llamaapi_config = config;
101    }
102}
103
104impl Model for LlamaAPIModel {
105    fn config(&self) -> &ModelConfig {
106        &self.config
107    }
108
109    fn update_config(&mut self, config: ModelConfig) {
110        self.config = config;
111    }
112
113    fn stream<'a>(
114        &'a self,
115        _messages: &'a [Message],
116        _tool_specs: Option<&'a [ToolSpec]>,
117        _system_prompt: Option<&'a str>,
118        _tool_choice: Option<ToolChoice>,
119        _system_prompt_content: Option<&'a [SystemContentBlock]>,
120    ) -> StreamEventStream<'a> {
121        Box::pin(futures::stream::once(async {
122            Err(StrandsError::ModelError {
123                message: "Llama API integration requires HTTP client implementation".into(),
124                source: None,
125            })
126        }))
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods should record each parameter in the config.
    #[test]
    fn test_llamaapi_config() {
        let cfg = LlamaAPIConfig::new("Llama-4-Maverick-17B-128E-Instruct-FP8")
            .with_max_completion_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(cfg.model_id, "Llama-4-Maverick-17B-128E-Instruct-FP8");
        assert_eq!(cfg.temperature, Some(0.7));
        assert_eq!(cfg.max_completion_tokens, Some(1000));
    }

    /// The generic model config should mirror the Llama model ID.
    #[test]
    fn test_llamaapi_model_creation() {
        let model = LlamaAPIModel::new(LlamaAPIConfig::new("test-model"));

        assert_eq!(model.config().model_id, "test-model");
    }
}