oxify_connect_llm/response_utils.rs
1//! Response post-processing utilities for LLM outputs
2//!
3//! This module provides helper functions for common response transformations,
4//! including code extraction, JSON parsing, markdown formatting, and content
5//! filtering. These utilities make it easier to work with structured LLM outputs.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use oxify_connect_llm::ResponseUtils;
11//!
12//! let response = "Here's a Python example:\n\
13//! ```python\n\
14//! def hello():\n\
15//! print(\"Hello, world!\")\n\
16//! ```\n";
17//!
18//! let code_blocks = ResponseUtils::extract_code_blocks(response);
19//! assert_eq!(code_blocks.len(), 1);
20//! assert_eq!(code_blocks[0].language, Some("python".to_string()));
21//! ```
22
23use serde_json::Value;
24
/// A single fenced code block pulled out of an LLM response.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeBlock {
    /// Language tag written after the opening fence, or `None` when the
    /// fence carried no tag.
    pub language: Option<String>,
    /// Text found between the opening and closing fence.
    pub code: String,
}
33
/// Response post-processing utilities.
///
/// A stateless unit struct used purely as a namespace: every helper is an
/// associated function called as `ResponseUtils::function_name(...)`.
pub struct ResponseUtils;
36
37impl ResponseUtils {
38 /// Extract code blocks from markdown-formatted response
39 ///
40 /// Recognizes both ``` and ` code fence formats.
41 ///
42 /// # Examples
43 ///
44 /// ```
45 /// use oxify_connect_llm::ResponseUtils;
46 ///
47 /// let response = "```rust\nfn main() {}\n```";
48 /// let blocks = ResponseUtils::extract_code_blocks(response);
49 /// assert_eq!(blocks.len(), 1);
50 /// assert_eq!(blocks[0].language, Some("rust".to_string()));
51 /// ```
52 pub fn extract_code_blocks(response: &str) -> Vec<CodeBlock> {
53 let mut blocks = Vec::new();
54 let mut in_code_block = false;
55 let mut current_language = None;
56 let mut current_code = String::new();
57
58 for line in response.lines() {
59 if line.starts_with("```") {
60 if in_code_block {
61 // End of code block
62 blocks.push(CodeBlock {
63 language: current_language.take(),
64 code: current_code.trim().to_string(),
65 });
66 current_code.clear();
67 in_code_block = false;
68 } else {
69 // Start of code block
70 let lang = line.trim_start_matches('`').trim();
71 current_language = if lang.is_empty() {
72 None
73 } else {
74 Some(lang.to_string())
75 };
76 in_code_block = true;
77 }
78 } else if in_code_block {
79 current_code.push_str(line);
80 current_code.push('\n');
81 }
82 }
83
84 // Handle unclosed code block
85 if in_code_block && !current_code.is_empty() {
86 blocks.push(CodeBlock {
87 language: current_language,
88 code: current_code.trim().to_string(),
89 });
90 }
91
92 blocks
93 }
94
95 /// Extract first code block of a specific language
96 ///
97 /// # Examples
98 ///
99 /// ```
100 /// use oxify_connect_llm::ResponseUtils;
101 ///
102 /// let response = "```python\nprint('hello')\n```\n```rust\nprintln!(\"hi\")\n```";
103 /// let python_code = ResponseUtils::extract_code_by_language(response, "python");
104 /// assert_eq!(python_code, Some("print('hello')".to_string()));
105 /// ```
106 pub fn extract_code_by_language(response: &str, language: &str) -> Option<String> {
107 Self::extract_code_blocks(response)
108 .into_iter()
109 .find(|block| {
110 block
111 .language
112 .as_ref()
113 .map(|lang| lang.eq_ignore_ascii_case(language))
114 .unwrap_or(false)
115 })
116 .map(|block| block.code)
117 }
118
119 /// Extract all code regardless of language
120 ///
121 /// Returns concatenated code from all blocks.
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use oxify_connect_llm::ResponseUtils;
127 ///
128 /// let response = "```\ncode1\n```\n```\ncode2\n```";
129 /// let code = ResponseUtils::extract_all_code(response);
130 /// assert!(code.contains("code1"));
131 /// assert!(code.contains("code2"));
132 /// ```
133 pub fn extract_all_code(response: &str) -> String {
134 Self::extract_code_blocks(response)
135 .into_iter()
136 .map(|block| block.code)
137 .collect::<Vec<_>>()
138 .join("\n\n")
139 }
140
141 /// Try to parse response as JSON
142 ///
143 /// Attempts to extract JSON from the response, handling cases where
144 /// the LLM wraps JSON in markdown code blocks.
145 ///
146 /// # Examples
147 ///
148 /// ```
149 /// use oxify_connect_llm::ResponseUtils;
150 ///
151 /// let response = r#"```json
152 /// {"name": "Alice", "age": 30}
153 /// ```"#;
154 /// let json = ResponseUtils::parse_json(response);
155 /// assert!(json.is_ok());
156 /// ```
157 pub fn parse_json(response: &str) -> Result<Value, serde_json::Error> {
158 // Try parsing as-is first
159 if let Ok(value) = serde_json::from_str(response.trim()) {
160 return Ok(value);
161 }
162
163 // Try extracting from JSON code block
164 if let Some(json_code) = Self::extract_code_by_language(response, "json") {
165 return serde_json::from_str(&json_code);
166 }
167
168 // Try first code block (might be unlabeled JSON)
169 if let Some(first_block) = Self::extract_code_blocks(response).first() {
170 if let Ok(value) = serde_json::from_str(&first_block.code) {
171 return Ok(value);
172 }
173 }
174
175 // Last resort: try parsing the whole response
176 serde_json::from_str(response.trim())
177 }
178
/// Remove common markdown formatting from a response.
///
/// * Every line is trimmed, and a leading run of `#` is removed when it
///   forms a header marker (followed by whitespace or the end of the line).
///   Text such as `#include` or `#hashtag` is left alone.
/// * All `*` and backtick characters are removed.
/// * `_` is removed when used as an emphasis delimiter. An underscore with
///   an alphanumeric character on both sides (as in `snake_case`) is kept.
///
/// The result is trimmed as a whole.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "# Title\n**Bold** and _italic_ with snake_case";
/// let plain = ResponseUtils::strip_markdown(response);
/// assert_eq!(plain, "Title\nBold and italic with snake_case");
/// ```
pub fn strip_markdown(response: &str) -> String {
    // Pass 1: per-line trimming and header-marker removal.
    let without_headers = response
        .lines()
        .map(|line| {
            let trimmed = line.trim();
            let rest = trimmed.trim_start_matches('#');
            let is_header = rest.len() < trimmed.len()
                && (rest.is_empty() || rest.starts_with(char::is_whitespace));
            if is_header {
                rest.trim()
            } else {
                trimmed
            }
        })
        .collect::<Vec<_>>()
        .join("\n");

    // Pass 2: drop emphasis and inline-code delimiters in a single sweep.
    let mut plain = String::with_capacity(without_headers.len());
    let mut previous: Option<char> = None;
    let mut chars = without_headers.chars().peekable();
    while let Some(current) = chars.next() {
        match current {
            '*' | '`' => {}
            '_' => {
                // Intraword underscores belong to identifiers, not emphasis.
                let intraword = matches!(previous, Some(p) if p.is_alphanumeric())
                    && matches!(chars.peek(), Some(n) if n.is_alphanumeric());
                if intraword {
                    plain.push('_');
                }
            }
            other => plain.push(other),
        }
        previous = Some(current);
    }

    plain.trim().to_string()
}
214
/// Extract the items of a numbered list.
///
/// A line is an item when, after trimming, it starts with one or more
/// ASCII digits, followed by `.` or `)`, followed by whitespace and
/// non-empty content (for example `1. First` or `2) Second`). Lines such
/// as `.hidden` (no number) or `3.14 is pi` (no whitespace after the
/// delimiter) are not list items.
///
/// Returns the trimmed content of each item in order of appearance.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "1. First\n2) Second\n3.14 is not an item";
/// let items = ResponseUtils::extract_numbered_list(response);
/// assert_eq!(items, vec!["First", "Second"]);
/// ```
pub fn extract_numbered_list(response: &str) -> Vec<String> {
    response
        .lines()
        .filter_map(|line| {
            let trimmed = line.trim();

            // The marker needs at least one digit before the delimiter.
            let digits_end = trimmed.find(|c: char| !c.is_ascii_digit())?;
            if digits_end == 0 {
                return None;
            }

            // Require whitespace after the delimiter so decimals and
            // version numbers are not mistaken for list markers.
            let rest = trimmed[digits_end..].strip_prefix(['.', ')'])?;
            if !rest.starts_with(char::is_whitespace) {
                return None;
            }

            let content = rest.trim();
            (!content.is_empty()).then(|| content.to_string())
        })
        .collect()
}
245
/// Extract the items of a bullet list.
///
/// A line is an item when, after trimming, it starts with `-`, `*` or `+`,
/// followed by whitespace and non-empty content. Requiring the whitespace
/// keeps bold text (`**Bold**`), horizontal rules (`---`) and negative
/// numbers (`-5`) from being reported as bullets.
///
/// Returns the trimmed content of each item in order of appearance.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "- First\n* Second\n- Third";
/// let items = ResponseUtils::extract_bullet_list(response);
/// assert_eq!(items.len(), 3);
/// ```
pub fn extract_bullet_list(response: &str) -> Vec<String> {
    response
        .lines()
        .filter_map(|line| {
            let rest = line.trim().strip_prefix(['-', '*', '+'])?;
            // A bullet marker must be separated from its content.
            if !rest.starts_with(char::is_whitespace) {
                return None;
            }
            let content = rest.trim();
            (!content.is_empty()).then(|| content.to_string())
        })
        .collect()
}
272
/// Truncate a response to at most `max_length` bytes and append `"..."`.
///
/// A response that already fits is returned unchanged, without an
/// ellipsis. Otherwise the cut is placed, in order of preference:
///
/// 1. just after the last `.`, `!` or `?` within the limit;
/// 2. at the last space within the limit;
/// 3. at the limit itself.
///
/// `max_length` is measured in bytes. If it falls inside a multi-byte
/// UTF-8 character, the limit is moved back to the start of that
/// character, so this function never panics on non-ASCII text.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let truncated = ResponseUtils::truncate("First sentence. Second sentence.", 20);
/// assert_eq!(truncated, "First sentence....");
///
/// // The limit lands inside 'é'; the cut backs off to a character boundary.
/// assert_eq!(ResponseUtils::truncate("héllo", 2), "h...");
/// ```
pub fn truncate(response: &str, max_length: usize) -> String {
    if response.len() <= max_length {
        return response.to_string();
    }

    // Slicing inside a multi-byte character would panic, so back off to the
    // nearest character boundary. Index 0 is always a boundary, so the loop
    // terminates without underflow.
    let mut limit = max_length;
    while !response.is_char_boundary(limit) {
        limit -= 1;
    }
    let head = &response[..limit];

    let cut = if let Some(pos) = head.rfind(['.', '!', '?']) {
        // Keep the (single-byte) terminator itself.
        pos + 1
    } else if let Some(pos) = head.rfind(' ') {
        pos
    } else {
        limit
    };

    format!("{}...", &head[..cut])
}
304
/// Extract `http://` and `https://` URLs from a response.
///
/// The response is split on whitespace, and each piece is searched for
/// the first occurrence of either scheme. This finds URLs wrapped in
/// markdown links (`[text](https://…)`), angle brackets or parentheses.
/// Trailing characters that are neither alphanumeric nor `/` are removed.
/// A scheme with nothing after it is not reported.
///
/// At most one URL is returned per whitespace-separated piece.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "Check out https://example.com and [docs](http://test.org).";
/// let urls = ResponseUtils::extract_urls(response);
/// assert_eq!(urls, vec!["https://example.com", "http://test.org"]);
/// ```
pub fn extract_urls(response: &str) -> Vec<String> {
    const SCHEMES: [&str; 2] = ["http://", "https://"];

    response
        .split_whitespace()
        .filter_map(|word| {
            // Earliest scheme occurrence; skips any leading markup.
            let start = SCHEMES
                .iter()
                .filter_map(|&scheme| word.find(scheme))
                .min()?;

            // Clean up trailing punctuation and closing brackets.
            let url =
                word[start..].trim_end_matches(|c: char| !c.is_alphanumeric() && c != '/');

            // A scheme with no host is not a usable URL.
            if url.ends_with("://") {
                None
            } else {
                Some(url.to_string())
            }
        })
        .collect()
}
327
/// Count the sentences in a response.
///
/// A sentence end is a run of one or more `.`, `!` or `?` characters,
/// optionally followed by closing quotes or brackets, that is then
/// followed by whitespace or the end of the text. A run such as `...` or
/// `?!` therefore counts once, and a `.` inside a number or domain name
/// (`3.14`, `example.com`) is not counted.
///
/// Text without any terminator yields `0`.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "First sentence. Second sentence! Third sentence?";
/// assert_eq!(ResponseUtils::count_sentences(response), 3);
/// assert_eq!(ResponseUtils::count_sentences("Wait... what?! Pi is 3.14."), 3);
/// ```
pub fn count_sentences(response: &str) -> usize {
    let is_terminator = |c: char| matches!(c, '.' | '!' | '?');
    let is_closer = |c: char| matches!(c, '"' | '\'' | ')' | ']' | '”' | '’');

    let mut count = 0;
    let mut chars = response.chars().peekable();
    while let Some(current) = chars.next() {
        if !is_terminator(current) {
            continue;
        }

        // Collapse a run of terminators ("...", "?!") into one sentence end.
        while matches!(chars.peek(), Some(&next) if is_terminator(next)) {
            chars.next();
        }
        // Allow closing quotes/brackets right after the terminator.
        while matches!(chars.peek(), Some(&next) if is_closer(next)) {
            chars.next();
        }

        // Only whitespace or end-of-text after the run marks a sentence end.
        let ends_sentence = match chars.peek() {
            Some(next) => next.is_whitespace(),
            None => true,
        };
        if ends_sentence {
            count += 1;
        }
    }
    count
}
344
/// Count the whitespace-separated words in a response.
///
/// Any run of Unicode whitespace acts as a single separator, so leading,
/// trailing and repeated whitespace do not produce empty words.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "This is a test response";
/// assert_eq!(ResponseUtils::count_words(response), 5);
/// ```
pub fn count_words(response: &str) -> usize {
    response.split_whitespace().fold(0, |total, _| total + 1)
}
358
/// Collapse every run of whitespace into a single space.
///
/// Leading and trailing whitespace is dropped; tabs and newlines are
/// treated like spaces.
///
/// # Examples
///
/// ```
/// use oxify_connect_llm::ResponseUtils;
///
/// let response = "  Too   many\t\tspaces \n";
/// let normalized = ResponseUtils::normalize_whitespace(response);
/// assert_eq!(normalized, "Too many spaces");
/// ```
pub fn normalize_whitespace(response: &str) -> String {
    let mut normalized = String::new();
    for word in response.split_whitespace() {
        // Words are never empty, so an empty buffer means "first word".
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized
}
373}
374
#[cfg(test)]
mod tests {
    //! Unit tests for [`ResponseUtils`].

    use super::*;

    #[test]
    fn test_extract_code_blocks() {
        let response = r#"
Here's some code:
```rust
fn main() {
    println!("Hello");
}
```

And Python:
```python
print("World")
```
"#;

        let blocks = ResponseUtils::extract_code_blocks(response);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].language, Some("rust".to_string()));
        assert_eq!(blocks[1].language, Some("python".to_string()));
        assert!(blocks[0].code.contains("fn main"));
        assert!(blocks[1].code.contains("print"));
    }

    #[test]
    fn test_extract_code_blocks_unclosed() {
        // A block left open at the end of the input is still returned.
        let blocks = ResponseUtils::extract_code_blocks("```rust\nfn main() {}");
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].language, Some("rust".to_string()));
        assert_eq!(blocks[0].code, "fn main() {}");
    }

    #[test]
    fn test_extract_code_by_language() {
        let response = "```rust\nlet x = 5;\n```\n```python\ny = 10\n```";
        let rust_code = ResponseUtils::extract_code_by_language(response, "rust");
        assert_eq!(rust_code, Some("let x = 5;".to_string()));

        let python_code = ResponseUtils::extract_code_by_language(response, "python");
        assert_eq!(python_code, Some("y = 10".to_string()));

        let js_code = ResponseUtils::extract_code_by_language(response, "javascript");
        assert_eq!(js_code, None);
    }

    #[test]
    fn test_extract_code_by_language_ignores_case() {
        let response = "```Rust\nlet x = 5;\n```";
        let code = ResponseUtils::extract_code_by_language(response, "rust");
        assert_eq!(code, Some("let x = 5;".to_string()));
    }

    #[test]
    fn test_parse_json() {
        let response = r#"```json
{
    "name": "Alice",
    "age": 30
}
```"#;

        let json = ResponseUtils::parse_json(response).unwrap();
        assert_eq!(json["name"], "Alice");
        assert_eq!(json["age"], 30);
    }

    #[test]
    fn test_parse_json_direct() {
        let response = r#"{"name": "Bob", "age": 25}"#;
        let json = ResponseUtils::parse_json(response).unwrap();
        assert_eq!(json["name"], "Bob");
    }

    #[test]
    fn test_parse_json_unlabeled_block() {
        let response = "```\n{\"ok\": true}\n```";
        let json = ResponseUtils::parse_json(response).unwrap();
        assert_eq!(json["ok"], true);
    }

    #[test]
    fn test_parse_json_invalid() {
        assert!(ResponseUtils::parse_json("definitely not json").is_err());
    }

    #[test]
    fn test_strip_markdown() {
        let response = "# Title\n**Bold** and *italic* text";
        let plain = ResponseUtils::strip_markdown(response);
        assert_eq!(plain, "Title\nBold and italic text");
    }

    #[test]
    fn test_extract_numbered_list() {
        let response = "1. First\n2. Second\n3. Third";
        let items = ResponseUtils::extract_numbered_list(response);
        assert_eq!(items, vec!["First", "Second", "Third"]);
    }

    #[test]
    fn test_extract_numbered_list_parentheses() {
        let response = "1) First\n2) Second";
        let items = ResponseUtils::extract_numbered_list(response);
        assert_eq!(items, vec!["First", "Second"]);
    }

    #[test]
    fn test_extract_bullet_list() {
        let response = "- Apple\n* Banana\n- Cherry";
        let items = ResponseUtils::extract_bullet_list(response);
        assert_eq!(items, vec!["Apple", "Banana", "Cherry"]);
    }

    #[test]
    fn test_truncate() {
        let response = "This is a long sentence. This is another sentence.";
        let truncated = ResponseUtils::truncate(response, 25);
        assert!(truncated.len() <= 28); // 25 + "..."
        assert!(truncated.ends_with("..."));
    }

    #[test]
    fn test_truncate_short_input_is_unchanged() {
        assert_eq!(ResponseUtils::truncate("short", 10), "short");
    }

    #[test]
    fn test_extract_urls() {
        let response = "Visit https://example.com and http://test.org for more info.";
        let urls = ResponseUtils::extract_urls(response);
        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"https://example.com".to_string()));
        assert!(urls.contains(&"http://test.org".to_string()));
    }

    #[test]
    fn test_count_sentences() {
        let response = "First. Second! Third?";
        assert_eq!(ResponseUtils::count_sentences(response), 3);
    }

    #[test]
    fn test_count_words() {
        let response = "This is a test";
        assert_eq!(ResponseUtils::count_words(response), 4);
    }

    #[test]
    fn test_normalize_whitespace() {
        // The input must actually contain redundant whitespace; otherwise the
        // assertion would hold even if the function returned its input as-is.
        let response = "  Too   many\t\tspaces \n";
        let normalized = ResponseUtils::normalize_whitespace(response);
        assert_eq!(normalized, "Too many spaces");
    }

    #[test]
    fn test_extract_all_code() {
        let response = "```\ncode1\n```\nSome text\n```\ncode2\n```";
        let code = ResponseUtils::extract_all_code(response);
        assert!(code.contains("code1"));
        assert!(code.contains("code2"));
    }
}