strands_agents/models/
writer.rs1use std::collections::HashMap;
7
8use crate::types::content::{Message, SystemContentBlock};
9use crate::types::errors::StrandsError;
10use crate::types::tools::{ToolChoice, ToolSpec};
11
12use super::{Model, ModelConfig, StreamEventStream};
13
14#[derive(Debug, Clone, Default)]
16pub struct WriterConfig {
17 pub model_id: String,
19 pub max_tokens: Option<u32>,
21 pub stop: Option<Vec<String>>,
23 pub stream_options: Option<HashMap<String, serde_json::Value>>,
25 pub temperature: Option<f64>,
27 pub top_p: Option<f64>,
29 pub api_key: Option<String>,
31 pub base_url: Option<String>,
33}
34
35impl WriterConfig {
36 pub fn new(model_id: impl Into<String>) -> Self {
38 Self {
39 model_id: model_id.into(),
40 ..Default::default()
41 }
42 }
43
44 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
46 self.max_tokens = Some(max_tokens);
47 self
48 }
49
50 pub fn with_temperature(mut self, temperature: f64) -> Self {
52 self.temperature = Some(temperature);
53 self
54 }
55
56 pub fn with_top_p(mut self, top_p: f64) -> Self {
58 self.top_p = Some(top_p);
59 self
60 }
61
62 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
64 self.stop = Some(stop);
65 self
66 }
67
68 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
70 self.api_key = Some(api_key.into());
71 self
72 }
73
74 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
76 self.base_url = Some(base_url.into());
77 self
78 }
79}
80
81pub struct WriterModel {
83 config: ModelConfig,
84 writer_config: WriterConfig,
85}
86
87impl WriterModel {
88 pub fn new(config: WriterConfig) -> Self {
90 Self {
91 config: ModelConfig::new(&config.model_id),
92 writer_config: config,
93 }
94 }
95
96 pub fn writer_config(&self) -> &WriterConfig {
98 &self.writer_config
99 }
100
101 pub fn update_writer_config(&mut self, config: WriterConfig) {
103 self.config = ModelConfig::new(&config.model_id);
104 self.writer_config = config;
105 }
106
107 pub fn is_palmyra_x5(&self) -> bool {
109 self.writer_config.model_id == "palmyra-x5"
110 }
111}
112
113impl Model for WriterModel {
114 fn config(&self) -> &ModelConfig {
115 &self.config
116 }
117
118 fn update_config(&mut self, config: ModelConfig) {
119 self.config = config;
120 }
121
122 fn stream<'a>(
123 &'a self,
124 _messages: &'a [Message],
125 _tool_specs: Option<&'a [ToolSpec]>,
126 _system_prompt: Option<&'a str>,
127 _tool_choice: Option<ToolChoice>,
128 _system_prompt_content: Option<&'a [SystemContentBlock]>,
129 ) -> StreamEventStream<'a> {
130 Box::pin(futures::stream::once(async {
131 Err(StrandsError::ModelError {
132 message: "Writer integration requires HTTP client implementation".into(),
133 source: None,
134 })
135 }))
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods store the values they are given.
    #[test]
    fn test_writer_config() {
        let cfg = WriterConfig::new("palmyra-x5")
            .with_max_tokens(1000)
            .with_temperature(0.7);

        assert_eq!(cfg.model_id, "palmyra-x5");
        assert_eq!(cfg.temperature, Some(0.7));
        assert_eq!(cfg.max_tokens, Some(1000));
    }

    /// The generic config mirrors the Writer model id, and a non-x5 id is not detected as x5.
    #[test]
    fn test_writer_model_creation() {
        let model = WriterModel::new(WriterConfig::new("palmyra-x4"));

        assert_eq!(model.config().model_id, "palmyra-x4");
        assert!(!model.is_palmyra_x5());
    }

    /// The exact id `"palmyra-x5"` is detected.
    #[test]
    fn test_palmyra_x5_detection() {
        let model = WriterModel::new(WriterConfig::new("palmyra-x5"));

        assert!(model.is_palmyra_x5());
    }
}