oxify_connect_llm/
response_utils.rs

1//! Response post-processing utilities for LLM outputs
2//!
3//! This module provides helper functions for common response transformations,
4//! including code extraction, JSON parsing, markdown formatting, and content
5//! filtering. These utilities make it easier to work with structured LLM outputs.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use oxify_connect_llm::ResponseUtils;
11//!
12//! let response = "Here's a Python example:\n\
13//!     ```python\n\
14//!     def hello():\n\
15//!         print(\"Hello, world!\")\n\
16//!     ```\n";
17//!
18//! let code_blocks = ResponseUtils::extract_code_blocks(response);
19//! assert_eq!(code_blocks.len(), 1);
20//! assert_eq!(code_blocks[0].language, Some("python".to_string()));
21//! ```
22
23use serde_json::Value;
24
/// A fenced code block extracted from an LLM response.
///
/// Values of this type are produced by `ResponseUtils::extract_code_blocks`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CodeBlock {
    /// Text that follows the backticks on the opening fence (for example
    /// `rust`), trimmed of surrounding whitespace; `None` when the fence
    /// carries no label.
    pub language: Option<String>,
    /// Lines found between the opening and closing fences, joined with `\n`
    /// and trimmed of leading and trailing whitespace.
    pub code: String,
}
33
/// Namespace for response post-processing helpers.
///
/// This is a stateless unit struct; the helpers in its `impl` block are
/// associated functions invoked as `ResponseUtils::name(...)`.
pub struct ResponseUtils;
36
37impl ResponseUtils {
38    /// Extract code blocks from markdown-formatted response
39    ///
40    /// Recognizes both ``` and ` code fence formats.
41    ///
42    /// # Examples
43    ///
44    /// ```
45    /// use oxify_connect_llm::ResponseUtils;
46    ///
47    /// let response = "```rust\nfn main() {}\n```";
48    /// let blocks = ResponseUtils::extract_code_blocks(response);
49    /// assert_eq!(blocks.len(), 1);
50    /// assert_eq!(blocks[0].language, Some("rust".to_string()));
51    /// ```
52    pub fn extract_code_blocks(response: &str) -> Vec<CodeBlock> {
53        let mut blocks = Vec::new();
54        let mut in_code_block = false;
55        let mut current_language = None;
56        let mut current_code = String::new();
57
58        for line in response.lines() {
59            if line.starts_with("```") {
60                if in_code_block {
61                    // End of code block
62                    blocks.push(CodeBlock {
63                        language: current_language.take(),
64                        code: current_code.trim().to_string(),
65                    });
66                    current_code.clear();
67                    in_code_block = false;
68                } else {
69                    // Start of code block
70                    let lang = line.trim_start_matches('`').trim();
71                    current_language = if lang.is_empty() {
72                        None
73                    } else {
74                        Some(lang.to_string())
75                    };
76                    in_code_block = true;
77                }
78            } else if in_code_block {
79                current_code.push_str(line);
80                current_code.push('\n');
81            }
82        }
83
84        // Handle unclosed code block
85        if in_code_block && !current_code.is_empty() {
86            blocks.push(CodeBlock {
87                language: current_language,
88                code: current_code.trim().to_string(),
89            });
90        }
91
92        blocks
93    }
94
95    /// Extract first code block of a specific language
96    ///
97    /// # Examples
98    ///
99    /// ```
100    /// use oxify_connect_llm::ResponseUtils;
101    ///
102    /// let response = "```python\nprint('hello')\n```\n```rust\nprintln!(\"hi\")\n```";
103    /// let python_code = ResponseUtils::extract_code_by_language(response, "python");
104    /// assert_eq!(python_code, Some("print('hello')".to_string()));
105    /// ```
106    pub fn extract_code_by_language(response: &str, language: &str) -> Option<String> {
107        Self::extract_code_blocks(response)
108            .into_iter()
109            .find(|block| {
110                block
111                    .language
112                    .as_ref()
113                    .map(|lang| lang.eq_ignore_ascii_case(language))
114                    .unwrap_or(false)
115            })
116            .map(|block| block.code)
117    }
118
119    /// Extract all code regardless of language
120    ///
121    /// Returns concatenated code from all blocks.
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use oxify_connect_llm::ResponseUtils;
127    ///
128    /// let response = "```\ncode1\n```\n```\ncode2\n```";
129    /// let code = ResponseUtils::extract_all_code(response);
130    /// assert!(code.contains("code1"));
131    /// assert!(code.contains("code2"));
132    /// ```
133    pub fn extract_all_code(response: &str) -> String {
134        Self::extract_code_blocks(response)
135            .into_iter()
136            .map(|block| block.code)
137            .collect::<Vec<_>>()
138            .join("\n\n")
139    }
140
141    /// Try to parse response as JSON
142    ///
143    /// Attempts to extract JSON from the response, handling cases where
144    /// the LLM wraps JSON in markdown code blocks.
145    ///
146    /// # Examples
147    ///
148    /// ```
149    /// use oxify_connect_llm::ResponseUtils;
150    ///
151    /// let response = r#"```json
152    /// {"name": "Alice", "age": 30}
153    /// ```"#;
154    /// let json = ResponseUtils::parse_json(response);
155    /// assert!(json.is_ok());
156    /// ```
157    pub fn parse_json(response: &str) -> Result<Value, serde_json::Error> {
158        // Try parsing as-is first
159        if let Ok(value) = serde_json::from_str(response.trim()) {
160            return Ok(value);
161        }
162
163        // Try extracting from JSON code block
164        if let Some(json_code) = Self::extract_code_by_language(response, "json") {
165            return serde_json::from_str(&json_code);
166        }
167
168        // Try first code block (might be unlabeled JSON)
169        if let Some(first_block) = Self::extract_code_blocks(response).first() {
170            if let Ok(value) = serde_json::from_str(&first_block.code) {
171                return Ok(value);
172            }
173        }
174
175        // Last resort: try parsing the whole response
176        serde_json::from_str(response.trim())
177    }
178
179    /// Remove markdown formatting from response
180    ///
181    /// Strips common markdown elements like headers, bold, italic, etc.
182    ///
183    /// # Examples
184    ///
185    /// ```
186    /// use oxify_connect_llm::ResponseUtils;
187    ///
188    /// let response = "# Title\n**Bold** and *italic*";
189    /// let plain = ResponseUtils::strip_markdown(response);
190    /// assert!(!plain.contains('*'));
191    /// assert!(!plain.contains('#'));
192    /// ```
193    pub fn strip_markdown(response: &str) -> String {
194        let mut result = response.to_string();
195
196        // Remove headers
197        result = result
198            .lines()
199            .map(|line| line.trim_start_matches('#').trim())
200            .collect::<Vec<_>>()
201            .join("\n");
202
203        // Remove bold and italic
204        result = result.replace("**", "");
205        result = result.replace("__", "");
206        result = result.replace('*', "");
207        result = result.replace('_', "");
208
209        // Remove inline code
210        result = result.replace('`', "");
211
212        result.trim().to_string()
213    }
214
215    /// Extract numbered list items from response
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// use oxify_connect_llm::ResponseUtils;
221    ///
222    /// let response = "1. First\n2. Second\n3. Third";
223    /// let items = ResponseUtils::extract_numbered_list(response);
224    /// assert_eq!(items, vec!["First", "Second", "Third"]);
225    /// ```
226    pub fn extract_numbered_list(response: &str) -> Vec<String> {
227        response
228            .lines()
229            .filter_map(|line| {
230                let trimmed = line.trim();
231                // Match patterns like "1. ", "2) ", etc.
232                if let Some(pos) = trimmed.find(['.', ')']) {
233                    let prefix = &trimmed[..pos];
234                    if prefix.chars().all(|c| c.is_ascii_digit()) {
235                        let content = trimmed[pos + 1..].trim();
236                        if !content.is_empty() {
237                            return Some(content.to_string());
238                        }
239                    }
240                }
241                None
242            })
243            .collect()
244    }
245
246    /// Extract bullet list items from response
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use oxify_connect_llm::ResponseUtils;
252    ///
253    /// let response = "- First\n* Second\n- Third";
254    /// let items = ResponseUtils::extract_bullet_list(response);
255    /// assert_eq!(items.len(), 3);
256    /// ```
257    pub fn extract_bullet_list(response: &str) -> Vec<String> {
258        response
259            .lines()
260            .filter_map(|line| {
261                let trimmed = line.trim();
262                if trimmed.starts_with('-') || trimmed.starts_with('*') {
263                    let content = trimmed[1..].trim();
264                    if !content.is_empty() {
265                        return Some(content.to_string());
266                    }
267                }
268                None
269            })
270            .collect()
271    }
272
273    /// Truncate response to a maximum length
274    ///
275    /// Tries to truncate at sentence boundary if possible.
276    ///
277    /// # Examples
278    ///
279    /// ```
280    /// use oxify_connect_llm::ResponseUtils;
281    ///
282    /// let response = "First sentence. Second sentence. Third sentence.";
283    /// let truncated = ResponseUtils::truncate(response, 20);
284    /// assert!(truncated.len() <= 23); // 20 + "..."
285    /// ```
286    pub fn truncate(response: &str, max_length: usize) -> String {
287        if response.len() <= max_length {
288            return response.to_string();
289        }
290
291        // Try to find sentence boundary
292        if let Some(pos) = response[..max_length].rfind(['.', '!', '?']) {
293            return format!("{}...", &response[..=pos]);
294        }
295
296        // Fall back to word boundary
297        if let Some(pos) = response[..max_length].rfind(' ') {
298            return format!("{}...", &response[..pos]);
299        }
300
301        // Last resort: hard truncate
302        format!("{}...", &response[..max_length])
303    }
304
305    /// Extract URLs from response
306    ///
307    /// # Examples
308    ///
309    /// ```
310    /// use oxify_connect_llm::ResponseUtils;
311    ///
312    /// let response = "Check out https://example.com and http://test.org";
313    /// let urls = ResponseUtils::extract_urls(response);
314    /// assert_eq!(urls.len(), 2);
315    /// ```
316    pub fn extract_urls(response: &str) -> Vec<String> {
317        let mut urls = Vec::new();
318        for word in response.split_whitespace() {
319            if word.starts_with("http://") || word.starts_with("https://") {
320                // Clean up trailing punctuation
321                let cleaned = word.trim_end_matches(|c: char| !c.is_alphanumeric() && c != '/');
322                urls.push(cleaned.to_string());
323            }
324        }
325        urls
326    }
327
328    /// Count sentences in response
329    ///
330    /// # Examples
331    ///
332    /// ```
333    /// use oxify_connect_llm::ResponseUtils;
334    ///
335    /// let response = "First sentence. Second sentence! Third sentence?";
336    /// assert_eq!(ResponseUtils::count_sentences(response), 3);
337    /// ```
338    pub fn count_sentences(response: &str) -> usize {
339        response
340            .chars()
341            .filter(|c| *c == '.' || *c == '!' || *c == '?')
342            .count()
343    }
344
345    /// Count words in response
346    ///
347    /// # Examples
348    ///
349    /// ```
350    /// use oxify_connect_llm::ResponseUtils;
351    ///
352    /// let response = "This is a test response";
353    /// assert_eq!(ResponseUtils::count_words(response), 5);
354    /// ```
355    pub fn count_words(response: &str) -> usize {
356        response.split_whitespace().count()
357    }
358
359    /// Remove extra whitespace from response
360    ///
361    /// # Examples
362    ///
363    /// ```
364    /// use oxify_connect_llm::ResponseUtils;
365    ///
366    /// let response = "Too   many    spaces";
367    /// let normalized = ResponseUtils::normalize_whitespace(response);
368    /// assert_eq!(normalized, "Too many spaces");
369    /// ```
370    pub fn normalize_whitespace(response: &str) -> String {
371        response.split_whitespace().collect::<Vec<_>>().join(" ")
372    }
373}
374
// Unit tests: one happy-path check per public helper of `ResponseUtils`.
#[cfg(test)]
mod tests {
    use super::*;

    // Two labeled blocks separated by prose are both found, in order.
    #[test]
    fn test_extract_code_blocks() {
        let response = r#"
Here's some code:
```rust
fn main() {
    println!("Hello");
}
```

And Python:
```python
print("World")
```
"#;

        let blocks = ResponseUtils::extract_code_blocks(response);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].language, Some("rust".to_string()));
        assert_eq!(blocks[1].language, Some("python".to_string()));
        assert!(blocks[0].code.contains("fn main"));
        assert!(blocks[1].code.contains("print"));
    }

    // Lookup by label returns the matching block, or `None` for an absent label.
    #[test]
    fn test_extract_code_by_language() {
        let response = "```rust\nlet x = 5;\n```\n```python\ny = 10\n```";
        let rust_code = ResponseUtils::extract_code_by_language(response, "rust");
        assert_eq!(rust_code, Some("let x = 5;".to_string()));

        let python_code = ResponseUtils::extract_code_by_language(response, "python");
        assert_eq!(python_code, Some("y = 10".to_string()));

        let js_code = ResponseUtils::extract_code_by_language(response, "javascript");
        assert_eq!(js_code, None);
    }

    // JSON wrapped in a `json`-labeled fence is unwrapped and parsed.
    #[test]
    fn test_parse_json() {
        let response = r#"```json
{
  "name": "Alice",
  "age": 30
}
```"#;

        let json = ResponseUtils::parse_json(response).unwrap();
        assert_eq!(json["name"], "Alice");
        assert_eq!(json["age"], 30);
    }

    // Bare JSON with no fence parses directly.
    #[test]
    fn test_parse_json_direct() {
        let response = r#"{"name": "Bob", "age": 25}"#;
        let json = ResponseUtils::parse_json(response).unwrap();
        assert_eq!(json["name"], "Bob");
    }

    // Header and emphasis markers are removed; line structure is kept.
    #[test]
    fn test_strip_markdown() {
        let response = "# Title\n**Bold** and *italic* text";
        let plain = ResponseUtils::strip_markdown(response);
        assert_eq!(plain, "Title\nBold and italic text");
    }

    // "N." markers are stripped and the item text returned in order.
    #[test]
    fn test_extract_numbered_list() {
        let response = "1. First\n2. Second\n3. Third";
        let items = ResponseUtils::extract_numbered_list(response);
        assert_eq!(items, vec!["First", "Second", "Third"]);
    }

    // Both `-` and `*` bullet markers are recognized.
    #[test]
    fn test_extract_bullet_list() {
        let response = "- Apple\n* Banana\n- Cherry";
        let items = ResponseUtils::extract_bullet_list(response);
        assert_eq!(items, vec!["Apple", "Banana", "Cherry"]);
    }

    // A truncated result stays within the limit plus the "..." suffix.
    #[test]
    fn test_truncate() {
        let response = "This is a long sentence. This is another sentence.";
        let truncated = ResponseUtils::truncate(response, 25);
        assert!(truncated.len() <= 28); // 25 + "..."
        assert!(truncated.ends_with("..."));
    }

    // Both schemes are found; surrounding prose is ignored.
    #[test]
    fn test_extract_urls() {
        let response = "Visit https://example.com and http://test.org for more info.";
        let urls = ResponseUtils::extract_urls(response);
        assert_eq!(urls.len(), 2);
        assert!(urls.contains(&"https://example.com".to_string()));
        assert!(urls.contains(&"http://test.org".to_string()));
    }

    // Each of `.`, `!` and `?` ends a sentence.
    #[test]
    fn test_count_sentences() {
        let response = "First. Second! Third?";
        assert_eq!(ResponseUtils::count_sentences(response), 3);
    }

    // Words are separated by whitespace.
    #[test]
    fn test_count_words() {
        let response = "This is a test";
        assert_eq!(ResponseUtils::count_words(response), 4);
    }

    // Runs of spaces collapse to a single space.
    #[test]
    fn test_normalize_whitespace() {
        let response = "Too   many    spaces";
        let normalized = ResponseUtils::normalize_whitespace(response);
        assert_eq!(normalized, "Too many spaces");
    }

    // Code from unlabeled blocks separated by prose is all included.
    #[test]
    fn test_extract_all_code() {
        let response = "```\ncode1\n```\nSome text\n```\ncode2\n```";
        let code = ResponseUtils::extract_all_code(response);
        assert!(code.contains("code1"));
        assert!(code.contains("code2"));
    }
}