1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10#[serde(rename_all = "snake_case")]
11pub enum ReplyFormat {
12 Markdown,
14 PlainText,
16 Html,
18 Json,
20 Code(String),
22 Table,
24 Bullet,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
30#[serde(rename_all = "snake_case")]
31pub enum ReplyTone {
32 Professional,
34 Casual,
36 Technical,
38 Friendly,
40 Concise,
42 Detailed,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ReplyDirective {
49 pub format: ReplyFormat,
51 pub max_length: Option<usize>,
53 pub tone: Option<ReplyTone>,
55 pub language: Option<String>,
57 pub audience: Option<String>,
59 pub include_sources: bool,
61 pub structured_output: Option<serde_json::Value>,
63}
64
65impl ReplyDirective {
66 pub fn new(format: ReplyFormat) -> Self {
68 Self {
69 format,
70 max_length: None,
71 tone: None,
72 language: None,
73 audience: None,
74 include_sources: false,
75 structured_output: None,
76 }
77 }
78
79 pub fn with_max_length(mut self, max_length: usize) -> Self {
81 self.max_length = Some(max_length);
82 self
83 }
84
85 pub fn with_tone(mut self, tone: ReplyTone) -> Self {
87 self.tone = Some(tone);
88 self
89 }
90
91 pub fn with_language(mut self, language: impl Into<String>) -> Self {
93 self.language = Some(language.into());
94 self
95 }
96
97 pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
99 self.audience = Some(audience.into());
100 self
101 }
102
103 pub fn with_sources(mut self) -> Self {
105 self.include_sources = true;
106 self
107 }
108
109 pub fn with_structured_output(mut self, schema: serde_json::Value) -> Self {
111 self.structured_output = Some(schema);
112 self
113 }
114}
115
116pub fn apply_directive(content: &str, directive: &ReplyDirective) -> String {
123 let mut result = match &directive.format {
124 ReplyFormat::Code(language) => {
125 format!("```{language}\n{content}\n```")
126 }
127 ReplyFormat::PlainText => strip_markdown(content),
128 _ => content.to_string(),
129 };
130
131 if let Some(max_len) = directive.max_length
132 && result.len() > max_len
133 {
134 result.truncate(max_len);
135 if let Some(last_space) = result.rfind(' ') {
137 result.truncate(last_space);
138 }
139 result.push_str("...");
140 }
141
142 result
143}
144
/// Converts Markdown-formatted `text` into approximate plain text.
///
/// This is a lightweight, line-oriented stripper, not a full Markdown parser.
/// Each line is trimmed of surrounding whitespace, and then at most one leading
/// block marker is removed:
///
/// * ATX heading hashes (`# `, `## `, ...),
/// * a bullet marker (`- ` or `* `),
/// * an ordered-list marker of ASCII digits followed by `. ` (e.g. `12. `).
///
/// Next, the inline markup characters `*` and `` ` `` are removed everywhere.
/// `_` is removed too, except between two alphanumeric characters. CommonMark
/// does not treat such intraword underscores as emphasis, so identifiers like
/// `snake_case` are kept intact.
///
/// The final string is trimmed.
fn strip_markdown(text: &str) -> String {
    /// Removes one leading block marker from an already-trimmed line, if present.
    fn strip_block_marker(line: &str) -> &str {
        if line.starts_with('#') {
            return line.trim_start_matches('#').trim_start();
        }
        if let Some(rest) = line
            .strip_prefix("- ")
            .or_else(|| line.strip_prefix("* "))
        {
            return rest;
        }
        // Digits are ASCII, so `digits` is also a valid byte index into `line`.
        let digits = line.bytes().take_while(u8::is_ascii_digit).count();
        if digits > 0 {
            if let Some(rest) = line[digits..].strip_prefix(". ") {
                return rest;
            }
        }
        line
    }

    let mut unmarked = String::with_capacity(text.len());
    for line in text.lines() {
        unmarked.push_str(strip_block_marker(line.trim()));
        unmarked.push('\n');
    }

    // Drop inline markup characters in a single pass. `prev` is the previous
    // *input* character, whether or not it was kept.
    let mut result = String::with_capacity(unmarked.len());
    let mut prev: Option<char> = None;
    let mut chars = unmarked.chars().peekable();
    while let Some(c) = chars.next() {
        let keep = match c {
            '*' | '`' => false,
            '_' => {
                let next = chars.peek().copied();
                prev.is_some_and(char::is_alphanumeric)
                    && next.is_some_and(char::is_alphanumeric)
            }
            _ => true,
        };
        if keep {
            result.push(c);
        }
        prev = Some(c);
    }

    result.trim().to_string()
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter should land in its matching field.
    #[test]
    fn test_directive_creation() {
        let built = ReplyDirective::new(ReplyFormat::Markdown)
            .with_tone(ReplyTone::Professional)
            .with_max_length(500)
            .with_language("en")
            .with_audience("developers")
            .with_sources();

        assert_eq!(built.format, ReplyFormat::Markdown);
        assert_eq!(built.tone, Some(ReplyTone::Professional));
        assert_eq!(built.max_length, Some(500));
        assert_eq!(built.language.as_deref(), Some("en"));
        assert_eq!(built.audience.as_deref(), Some("developers"));
        assert!(built.include_sources);
    }

    /// Over-long content is cut and suffixed with an ellipsis.
    #[test]
    fn test_apply_truncation() {
        const LIMIT: usize = 20;
        let directive = ReplyDirective::new(ReplyFormat::Markdown).with_max_length(LIMIT);

        let output = apply_directive(
            "This is a long piece of content that should be truncated",
            &directive,
        );

        assert!(output.ends_with("..."));
        // The ellipsis is appended after cutting, so it may exceed the limit.
        assert!(output.len() <= LIMIT + "...".len());
    }

    /// The `Code` format wraps content in a fenced block tagged with the language.
    #[test]
    fn test_apply_code_format() {
        let directive = ReplyDirective::new(ReplyFormat::Code("rust".to_string()));
        let snippet = "fn main() {}";

        let output = apply_directive(snippet, &directive);

        assert!(output.starts_with("```rust\n"));
        assert!(output.ends_with("\n```"));
        assert!(output.contains(snippet));
    }

    /// The `PlainText` format removes headings and emphasis markers.
    #[test]
    fn test_apply_plain_text() {
        let directive = ReplyDirective::new(ReplyFormat::PlainText);
        let markdown = "# Heading\n\n**Bold text** and *italic text*\n\n- Item one\n- Item two";

        let output = apply_directive(markdown, &directive);

        for forbidden in ["#", "**", "*"] {
            assert!(!output.contains(forbidden));
        }
        assert!(output.contains("Heading"));
        assert!(output.contains("Bold text"));
    }

    /// Every format and tone variant survives a JSON round trip.
    #[test]
    fn test_format_serialization() {
        let all_formats = [
            ReplyFormat::Markdown,
            ReplyFormat::PlainText,
            ReplyFormat::Html,
            ReplyFormat::Json,
            ReplyFormat::Code("python".to_string()),
            ReplyFormat::Table,
            ReplyFormat::Bullet,
        ];
        for original in all_formats.iter() {
            let encoded = serde_json::to_string(original).expect("serialize format");
            let decoded: ReplyFormat =
                serde_json::from_str(&encoded).expect("deserialize format");
            assert_eq!(&decoded, original);
        }

        let all_tones = [
            ReplyTone::Professional,
            ReplyTone::Casual,
            ReplyTone::Technical,
            ReplyTone::Friendly,
            ReplyTone::Concise,
            ReplyTone::Detailed,
        ];
        for original in all_tones.iter() {
            let encoded = serde_json::to_string(original).expect("serialize tone");
            let decoded: ReplyTone = serde_json::from_str(&encoded).expect("deserialize tone");
            assert_eq!(&decoded, original);
        }
    }
}