strands_agents/models/
writer.rs

1//! Writer model provider.
2//!
3//! This provider integrates with Writer's Palmyra models.
4//! See: https://dev.writer.com/home/introduction
5
6use std::collections::HashMap;
7
8use crate::types::content::{Message, SystemContentBlock};
9use crate::types::errors::StrandsError;
10use crate::types::tools::{ToolChoice, ToolSpec};
11
12use super::{Model, ModelConfig, StreamEventStream};
13
14/// Configuration for Writer models.
15#[derive(Debug, Clone, Default)]
16pub struct WriterConfig {
17    /// Model name to use (e.g., palmyra-x5, palmyra-x4).
18    pub model_id: String,
19    /// Maximum number of tokens to generate.
20    pub max_tokens: Option<u32>,
21    /// Stop sequences.
22    pub stop: Option<Vec<String>>,
23    /// Additional options for streaming.
24    pub stream_options: Option<HashMap<String, serde_json::Value>>,
25    /// Sampling temperature.
26    pub temperature: Option<f64>,
27    /// Top-p (nucleus sampling).
28    pub top_p: Option<f64>,
29    /// API key for authentication.
30    pub api_key: Option<String>,
31    /// Base URL for the API.
32    pub base_url: Option<String>,
33}
34
35impl WriterConfig {
36    /// Create a new Writer config.
37    pub fn new(model_id: impl Into<String>) -> Self {
38        Self {
39            model_id: model_id.into(),
40            ..Default::default()
41        }
42    }
43
44    /// Set max tokens.
45    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
46        self.max_tokens = Some(max_tokens);
47        self
48    }
49
50    /// Set temperature.
51    pub fn with_temperature(mut self, temperature: f64) -> Self {
52        self.temperature = Some(temperature);
53        self
54    }
55
56    /// Set top-p.
57    pub fn with_top_p(mut self, top_p: f64) -> Self {
58        self.top_p = Some(top_p);
59        self
60    }
61
62    /// Set stop sequences.
63    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
64        self.stop = Some(stop);
65        self
66    }
67
68    /// Set API key.
69    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
70        self.api_key = Some(api_key.into());
71        self
72    }
73
74    /// Set base URL.
75    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
76        self.base_url = Some(base_url.into());
77        self
78    }
79}
80
81/// Writer model provider implementation.
82pub struct WriterModel {
83    config: ModelConfig,
84    writer_config: WriterConfig,
85}
86
87impl WriterModel {
88    /// Create a new Writer model.
89    pub fn new(config: WriterConfig) -> Self {
90        Self {
91            config: ModelConfig::new(&config.model_id),
92            writer_config: config,
93        }
94    }
95
96    /// Get the Writer configuration.
97    pub fn writer_config(&self) -> &WriterConfig {
98        &self.writer_config
99    }
100
101    /// Update the Writer configuration.
102    pub fn update_writer_config(&mut self, config: WriterConfig) {
103        self.config = ModelConfig::new(&config.model_id);
104        self.writer_config = config;
105    }
106
107    /// Check if this is a Palmyra X5 model (supports vision).
108    pub fn is_palmyra_x5(&self) -> bool {
109        self.writer_config.model_id == "palmyra-x5"
110    }
111}
112
113impl Model for WriterModel {
114    fn config(&self) -> &ModelConfig {
115        &self.config
116    }
117
118    fn update_config(&mut self, config: ModelConfig) {
119        self.config = config;
120    }
121
122    fn stream<'a>(
123        &'a self,
124        _messages: &'a [Message],
125        _tool_specs: Option<&'a [ToolSpec]>,
126        _system_prompt: Option<&'a str>,
127        _tool_choice: Option<ToolChoice>,
128        _system_prompt_content: Option<&'a [SystemContentBlock]>,
129    ) -> StreamEventStream<'a> {
130        Box::pin(futures::stream::once(async {
131            Err(StrandsError::ModelError {
132                message: "Writer integration requires HTTP client implementation".into(),
133                source: None,
134            })
135        }))
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_writer_config() {
        let config = WriterConfig::new("palmyra-x5")
            .with_temperature(0.7)
            .with_max_tokens(1000);

        assert_eq!(config.model_id, "palmyra-x5");
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, Some(1000));
    }

    /// `new` must leave every optional setting unset.
    #[test]
    fn test_writer_config_defaults() {
        let config = WriterConfig::new("palmyra-x4");

        assert_eq!(config.model_id, "palmyra-x4");
        assert!(config.max_tokens.is_none());
        assert!(config.temperature.is_none());
        assert!(config.top_p.is_none());
        assert!(config.stop.is_none());
        assert!(config.stream_options.is_none());
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
    }

    /// Covers the builders not exercised by `test_writer_config`.
    #[test]
    fn test_writer_config_remaining_builders() {
        let config = WriterConfig::new("palmyra-x5")
            .with_top_p(0.9)
            .with_stop(vec!["END".to_string()])
            .with_api_key("key")
            .with_base_url("https://example.invalid");

        assert_eq!(config.top_p, Some(0.9));
        assert_eq!(config.stop, Some(vec!["END".to_string()]));
        assert_eq!(config.api_key.as_deref(), Some("key"));
        assert_eq!(config.base_url.as_deref(), Some("https://example.invalid"));
    }

    #[test]
    fn test_writer_model_creation() {
        let config = WriterConfig::new("palmyra-x4");
        let model = WriterModel::new(config);

        assert_eq!(model.config().model_id, "palmyra-x4");
        assert!(!model.is_palmyra_x5());
    }

    #[test]
    fn test_palmyra_x5_detection() {
        let config = WriterConfig::new("palmyra-x5");
        let model = WriterModel::new(config);

        assert!(model.is_palmyra_x5());
    }

    /// `update_writer_config` must keep the generic config's model ID in step.
    #[test]
    fn test_update_writer_config_resyncs_model_config() {
        let mut model = WriterModel::new(WriterConfig::new("palmyra-x4"));
        model.update_writer_config(WriterConfig::new("palmyra-x5"));

        assert_eq!(model.config().model_id, "palmyra-x5");
        assert_eq!(model.writer_config().model_id, "palmyra-x5");
        assert!(model.is_palmyra_x5());
    }
}