agent_chain_core/language_models/utils.rs

1//! Utility functions for language models.
2//!
3//! This module contains helper functions for working with language models,
4//! including message normalization and content block utilities.
5//! Mirrors `langchain_core.language_models._utils`.
6
7use std::collections::HashMap;
8
9use regex::Regex;
10use serde::{Deserialize, Serialize};
11
/// Restricts [`is_openai_data_block`] to a single kind of OpenAI multimodal block.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DataBlockFilter {
    /// Accept only `"image_url"` blocks.
    Image,
    /// Accept only `"input_audio"` blocks.
    Audio,
    /// Accept only `"file"` blocks.
    File,
}
22
23/// Check whether a block contains multimodal data in OpenAI Chat Completions format.
24///
25/// Supports both data and ID-style blocks (e.g. `'file_data'` and `'file_id'`)
26///
27/// # Arguments
28///
29/// * `block` - The content block to check.
30/// * `filter` - If provided, only return true for blocks matching this specific type.
31///
32/// # Returns
33///
34/// `true` if the block is a valid OpenAI data block and matches the filter (if provided).
35pub fn is_openai_data_block(block: &serde_json::Value, filter: Option<DataBlockFilter>) -> bool {
36    let block_type = block.get("type").and_then(|t| t.as_str());
37
38    match block_type {
39        Some("image_url") => {
40            if let Some(f) = filter
41                && f != DataBlockFilter::Image
42            {
43                return false;
44            }
45
46            // Check for valid image_url structure
47            if let Some(image_url) = block.get("image_url")
48                && let Some(obj) = image_url.as_object()
49            {
50                return obj.get("url").and_then(|u| u.as_str()).is_some();
51            }
52            false
53        }
54        Some("input_audio") => {
55            if let Some(f) = filter
56                && f != DataBlockFilter::Audio
57            {
58                return false;
59            }
60
61            // Check for valid input_audio structure
62            if let Some(audio) = block.get("input_audio")
63                && let Some(obj) = audio.as_object()
64            {
65                let has_data = obj.get("data").and_then(|d| d.as_str()).is_some();
66                let has_format = obj.get("format").and_then(|f| f.as_str()).is_some();
67                return has_data && has_format;
68            }
69            false
70        }
71        Some("file") => {
72            if let Some(f) = filter
73                && f != DataBlockFilter::File
74            {
75                return false;
76            }
77
78            // Check for valid file structure
79            if let Some(file) = block.get("file")
80                && let Some(obj) = file.as_object()
81            {
82                let has_file_data = obj.get("file_data").and_then(|d| d.as_str()).is_some();
83                let has_file_id = obj.get("file_id").and_then(|d| d.as_str()).is_some();
84                return has_file_data || has_file_id;
85            }
86            false
87        }
88        _ => false,
89    }
90}
91
92/// Parsed data URI components.
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct ParsedDataUri {
95    /// Source type (always "base64" for data URIs).
96    pub source_type: String,
97    /// The base64-encoded data.
98    pub data: String,
99    /// The MIME type of the data.
100    pub mime_type: String,
101}
102
103/// Parse a data URI into its components.
104///
105/// # Arguments
106///
107/// * `uri` - The data URI to parse (e.g., "data:image/jpeg;base64,/9j/4AAQ...")
108///
109/// # Returns
110///
111/// `Some(ParsedDataUri)` if parsing succeeds, `None` otherwise.
112pub fn parse_data_uri(uri: &str) -> Option<ParsedDataUri> {
113    let re = Regex::new(r"^data:(?P<mime_type>[^;]+);base64,(?P<data>.+)$").ok()?;
114    let captures = re.captures(uri)?;
115
116    let mime_type = captures.name("mime_type")?.as_str();
117    let data = captures.name("data")?.as_str();
118
119    if mime_type.is_empty() || data.is_empty() {
120        return None;
121    }
122
123    Some(ParsedDataUri {
124        source_type: "base64".to_string(),
125        data: data.to_string(),
126        mime_type: mime_type.to_string(),
127    })
128}
129
/// Produce placeholder token IDs by splitting `text` on whitespace.
///
/// This is only a fallback for when no model-specific tokenizer is available. Each
/// whitespace-separated word becomes one "token". The returned IDs are simply the word
/// positions `0, 1, 2, …`, not vocabulary IDs.
///
/// # Arguments
///
/// * `text` - The text to tokenize.
///
/// # Returns
///
/// One ID per whitespace-separated word, numbered from zero.
pub fn get_token_ids_default(text: &str) -> Vec<u32> {
    let word_count = text.split_whitespace().count();
    (0..word_count).map(|position| position as u32).collect()
}
150
/// Roughly estimate how many tokens `text` would occupy.
///
/// Uses the rule of thumb of about four characters per token for English text. The division
/// is rounded up, so any non-empty text yields at least one token. Characters are counted as
/// Unicode scalar values, not bytes. Use a model-specific tokenizer for exact counts.
///
/// # Arguments
///
/// * `text` - The text to count tokens for.
///
/// # Returns
///
/// The estimated token count (`0` for empty text).
pub fn estimate_token_count(text: &str) -> usize {
    const CHARS_PER_TOKEN: usize = 4;

    let chars = text.chars().count();
    chars / CHARS_PER_TOKEN + usize::from(chars % CHARS_PER_TOKEN != 0)
}
168
169/// Convert a v0 content block format to v1 format.
170///
171/// LangChain v0 content blocks had different structure than v1.
172/// This function converts the older format to the newer standard.
173pub fn convert_legacy_v0_content_block_to_v1(
174    block: &HashMap<String, serde_json::Value>,
175) -> HashMap<String, serde_json::Value> {
176    let mut result = HashMap::new();
177
178    // Get the type
179    let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("text");
180    result.insert(
181        "type".to_string(),
182        serde_json::Value::String(block_type.to_string()),
183    );
184
185    // Handle different source types
186    let source_type = block.get("source_type").and_then(|t| t.as_str());
187
188    match source_type {
189        Some("base64") => {
190            if let Some(data) = block.get("data") {
191                result.insert("base64".to_string(), data.clone());
192            }
193            if let Some(mime_type) = block.get("mime_type") {
194                result.insert("mime_type".to_string(), mime_type.clone());
195            }
196        }
197        Some("url") => {
198            if let Some(url) = block.get("url") {
199                result.insert("url".to_string(), url.clone());
200            }
201            if let Some(mime_type) = block.get("mime_type") {
202                result.insert("mime_type".to_string(), mime_type.clone());
203            }
204        }
205        Some("id") => {
206            if let Some(id) = block.get("id") {
207                result.insert("file_id".to_string(), id.clone());
208            }
209        }
210        Some("text") => {
211            if let Some(text) = block.get("text") {
212                result.insert("text".to_string(), text.clone());
213            }
214        }
215        _ => {
216            // Copy all other fields
217            for (key, value) in block {
218                if key != "source_type" {
219                    result.insert(key.clone(), value.clone());
220                }
221            }
222        }
223    }
224
225    result
226}
227
228/// Convert an OpenAI format content block to a standard data block.
229pub fn convert_openai_format_to_data_block(
230    block: &serde_json::Value,
231) -> HashMap<String, serde_json::Value> {
232    let mut result = HashMap::new();
233
234    let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
235
236    match block_type {
237        "image_url" => {
238            result.insert(
239                "type".to_string(),
240                serde_json::Value::String("image".to_string()),
241            );
242
243            if let Some(image_url) = block.get("image_url").and_then(|i| i.as_object()) {
244                if let Some(url) = image_url.get("url").and_then(|u| u.as_str()) {
245                    // Check if it's a data URI
246                    if let Some(parsed) = parse_data_uri(url) {
247                        result.insert("base64".to_string(), serde_json::Value::String(parsed.data));
248                        result.insert(
249                            "mime_type".to_string(),
250                            serde_json::Value::String(parsed.mime_type),
251                        );
252                    } else {
253                        result.insert(
254                            "url".to_string(),
255                            serde_json::Value::String(url.to_string()),
256                        );
257                    }
258                }
259                if let Some(detail) = image_url.get("detail") {
260                    result.insert("detail".to_string(), detail.clone());
261                }
262            }
263        }
264        "input_audio" => {
265            result.insert(
266                "type".to_string(),
267                serde_json::Value::String("audio".to_string()),
268            );
269
270            if let Some(audio) = block.get("input_audio").and_then(|a| a.as_object()) {
271                if let Some(data) = audio.get("data").and_then(|d| d.as_str()) {
272                    result.insert(
273                        "base64".to_string(),
274                        serde_json::Value::String(data.to_string()),
275                    );
276                }
277                if let Some(format) = audio.get("format").and_then(|f| f.as_str()) {
278                    // Map format to mime_type
279                    let mime_type = match format {
280                        "wav" => "audio/wav",
281                        "mp3" => "audio/mpeg",
282                        _ => format,
283                    };
284                    result.insert(
285                        "mime_type".to_string(),
286                        serde_json::Value::String(mime_type.to_string()),
287                    );
288                }
289            }
290        }
291        "file" => {
292            result.insert(
293                "type".to_string(),
294                serde_json::Value::String("file".to_string()),
295            );
296
297            if let Some(file) = block.get("file").and_then(|f| f.as_object()) {
298                if let Some(file_data) = file.get("file_data").and_then(|d| d.as_str()) {
299                    result.insert(
300                        "base64".to_string(),
301                        serde_json::Value::String(file_data.to_string()),
302                    );
303                }
304                if let Some(file_id) = file.get("file_id").and_then(|d| d.as_str()) {
305                    result.insert(
306                        "file_id".to_string(),
307                        serde_json::Value::String(file_id.to_string()),
308                    );
309                }
310                if let Some(filename) = file.get("filename").and_then(|f| f.as_str()) {
311                    result.insert(
312                        "filename".to_string(),
313                        serde_json::Value::String(filename.to_string()),
314                    );
315                }
316            }
317        }
318        _ => {
319            // Copy all fields for unknown types
320            if let Some(obj) = block.as_object() {
321                for (key, value) in obj {
322                    result.insert(key.clone(), value.clone());
323                }
324            }
325        }
326    }
327
328    result
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a legacy v0 block (`HashMap`) from a JSON object literal.
    fn v0_block(value: serde_json::Value) -> HashMap<String, serde_json::Value> {
        serde_json::from_value(value).expect("test fixture must be a JSON object")
    }

    #[test]
    fn test_is_openai_data_block_image() {
        let block = json!({
            "type": "image_url",
            "image_url": {
                "url": "https://example.com/image.png"
            }
        });

        assert!(is_openai_data_block(&block, None));
        assert!(is_openai_data_block(&block, Some(DataBlockFilter::Image)));
        assert!(!is_openai_data_block(&block, Some(DataBlockFilter::Audio)));
    }

    #[test]
    fn test_is_openai_data_block_audio() {
        let block = json!({
            "type": "input_audio",
            "input_audio": {
                "data": "base64data",
                "format": "wav"
            }
        });

        assert!(is_openai_data_block(&block, None));
        assert!(is_openai_data_block(&block, Some(DataBlockFilter::Audio)));
        assert!(!is_openai_data_block(&block, Some(DataBlockFilter::Image)));
    }

    #[test]
    fn test_is_openai_data_block_file() {
        let block = json!({
            "type": "file",
            "file": {
                "file_id": "file-123"
            }
        });

        assert!(is_openai_data_block(&block, None));
        assert!(is_openai_data_block(&block, Some(DataBlockFilter::File)));
        assert!(!is_openai_data_block(&block, Some(DataBlockFilter::Image)));
    }

    #[test]
    fn test_is_openai_data_block_file_data() {
        let block = json!({
            "type": "file",
            "file": {
                "file_data": "data:application/pdf;base64,JVBERi0=",
                "filename": "doc.pdf"
            }
        });

        assert!(is_openai_data_block(&block, None));
        assert!(is_openai_data_block(&block, Some(DataBlockFilter::File)));
        assert!(!is_openai_data_block(&block, Some(DataBlockFilter::Audio)));
    }

    #[test]
    fn test_is_openai_data_block_missing_required_fields() {
        // `image_url.url` is required.
        let block = json!({"type": "image_url", "image_url": {"detail": "high"}});
        assert!(!is_openai_data_block(&block, None));

        // `image_url` must be an object, not a bare string.
        let block = json!({"type": "image_url", "image_url": "https://example.com/image.png"});
        assert!(!is_openai_data_block(&block, None));

        // Audio needs both `data` and `format`.
        let block = json!({"type": "input_audio", "input_audio": {"data": "base64data"}});
        assert!(!is_openai_data_block(&block, None));

        // Files need `file_data` or `file_id`.
        let block = json!({"type": "file", "file": {"filename": "doc.pdf"}});
        assert!(!is_openai_data_block(&block, None));

        // Non-object input.
        assert!(!is_openai_data_block(&json!("image_url"), None));
    }

    #[test]
    fn test_is_openai_data_block_invalid() {
        let block = json!({
            "type": "text",
            "text": "Hello"
        });

        assert!(!is_openai_data_block(&block, None));
    }

    #[test]
    fn test_parse_data_uri() {
        let uri = "data:image/jpeg;base64,/9j/4AAQSkZJRg==";
        let parsed = parse_data_uri(uri).unwrap();

        assert_eq!(parsed.source_type, "base64");
        assert_eq!(parsed.mime_type, "image/jpeg");
        assert_eq!(parsed.data, "/9j/4AAQSkZJRg==");
    }

    #[test]
    fn test_parse_data_uri_invalid() {
        let uri = "https://example.com/image.png";
        assert!(parse_data_uri(uri).is_none());

        let uri = "data:;base64,";
        assert!(parse_data_uri(uri).is_none());
    }

    #[test]
    fn test_parse_data_uri_requires_base64_marker() {
        assert!(parse_data_uri("data:text/plain,hello").is_none());
        assert!(parse_data_uri("data:text/plain;charset=utf-8;base64,aGk=").is_none());
        assert!(parse_data_uri("data:image/png;base64,abc\ndef").is_none());
    }

    #[test]
    fn test_estimate_token_count() {
        // 13 chars / 4, rounded up.
        assert_eq!(estimate_token_count("Hello, world!"), 4);
        assert_eq!(estimate_token_count(""), 0);
        assert_eq!(estimate_token_count("abcd"), 1);
        assert_eq!(estimate_token_count("abcde"), 2);
        // Counts characters, not UTF-8 bytes (11 chars, 13 bytes).
        assert_eq!(estimate_token_count("héllo wörld"), 3);
    }

    #[test]
    fn test_get_token_ids_default() {
        let text = "Hello world test";
        let ids = get_token_ids_default(text);
        assert_eq!(ids.len(), 3);
        assert_eq!(ids, vec![0, 1, 2]);

        assert!(get_token_ids_default("").is_empty());
        assert!(get_token_ids_default(" \t\n ").is_empty());
    }

    #[test]
    fn test_convert_openai_format_to_data_block_image_url() {
        let block = json!({
            "type": "image_url",
            "image_url": {
                "url": "https://example.com/image.png",
                "detail": "high"
            }
        });

        let result = convert_openai_format_to_data_block(&block);

        assert_eq!(result.get("type").unwrap(), "image");
        assert_eq!(result.get("url").unwrap(), "https://example.com/image.png");
        assert_eq!(result.get("detail").unwrap(), "high");
    }

    #[test]
    fn test_convert_openai_format_to_data_block_data_uri() {
        let block = json!({
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,iVBORw0KGgo="
            }
        });

        let result = convert_openai_format_to_data_block(&block);

        assert_eq!(result.get("type").unwrap(), "image");
        assert_eq!(result.get("base64").unwrap(), "iVBORw0KGgo=");
        assert_eq!(result.get("mime_type").unwrap(), "image/png");
    }

    #[test]
    fn test_convert_openai_format_to_data_block_audio() {
        let block = json!({
            "type": "input_audio",
            "input_audio": {"data": "UklGRg==", "format": "wav"}
        });
        let result = convert_openai_format_to_data_block(&block);
        assert_eq!(result.get("type").unwrap(), "audio");
        assert_eq!(result.get("base64").unwrap(), "UklGRg==");
        assert_eq!(result.get("mime_type").unwrap(), "audio/wav");

        let block = json!({
            "type": "input_audio",
            "input_audio": {"data": "SUQz", "format": "mp3"}
        });
        let result = convert_openai_format_to_data_block(&block);
        assert_eq!(result.get("mime_type").unwrap(), "audio/mpeg");
    }

    #[test]
    fn test_convert_openai_format_to_data_block_file_id() {
        let block = json!({
            "type": "file",
            "file": {"file_id": "file-123", "filename": "report.pdf"}
        });

        let result = convert_openai_format_to_data_block(&block);

        assert_eq!(result.get("type").unwrap(), "file");
        assert_eq!(result.get("file_id").unwrap(), "file-123");
        assert_eq!(result.get("filename").unwrap(), "report.pdf");
        assert!(!result.contains_key("base64"));
    }

    #[test]
    fn test_convert_openai_format_to_data_block_passthrough() {
        let block = json!({"type": "text", "text": "hi"});

        let result = convert_openai_format_to_data_block(&block);

        assert_eq!(result.len(), 2);
        assert_eq!(result.get("type").unwrap(), "text");
        assert_eq!(result.get("text").unwrap(), "hi");
    }

    #[test]
    fn test_convert_legacy_v0_content_block_to_v1_base64() {
        let mut block = HashMap::new();
        block.insert("type".to_string(), json!("image"));
        block.insert("source_type".to_string(), json!("base64"));
        block.insert("data".to_string(), json!("base64data"));
        block.insert("mime_type".to_string(), json!("image/png"));

        let result = convert_legacy_v0_content_block_to_v1(&block);

        assert_eq!(result.get("type").unwrap(), "image");
        assert_eq!(result.get("base64").unwrap(), "base64data");
        assert_eq!(result.get("mime_type").unwrap(), "image/png");
        assert!(!result.contains_key("source_type"));
    }

    #[test]
    fn test_convert_legacy_v0_content_block_to_v1_url() {
        let block = v0_block(json!({
            "type": "image",
            "source_type": "url",
            "url": "https://example.com/image.png",
            "mime_type": "image/png"
        }));

        let result = convert_legacy_v0_content_block_to_v1(&block);

        assert_eq!(result.len(), 3);
        assert_eq!(result.get("type").unwrap(), "image");
        assert_eq!(result.get("url").unwrap(), "https://example.com/image.png");
        assert_eq!(result.get("mime_type").unwrap(), "image/png");
        assert!(!result.contains_key("source_type"));
    }

    #[test]
    fn test_convert_legacy_v0_content_block_to_v1_id() {
        let block = v0_block(json!({
            "type": "file",
            "source_type": "id",
            "id": "file-abc"
        }));

        let result = convert_legacy_v0_content_block_to_v1(&block);

        assert_eq!(result.get("type").unwrap(), "file");
        assert_eq!(result.get("file_id").unwrap(), "file-abc");
        assert!(!result.contains_key("id"));
        assert!(!result.contains_key("source_type"));
    }

    #[test]
    fn test_convert_legacy_v0_content_block_to_v1_passthrough() {
        // No `source_type`: the block is returned as-is.
        let block = v0_block(json!({"type": "text", "text": "hello", "extra": 1}));
        let result = convert_legacy_v0_content_block_to_v1(&block);
        assert_eq!(result.len(), 3);
        assert_eq!(result.get("type").unwrap(), "text");
        assert_eq!(result.get("text").unwrap(), "hello");
        assert_eq!(result["extra"], json!(1));

        // Unknown `source_type`: only the discriminator is removed.
        let block = v0_block(json!({"type": "image", "source_type": "custom", "foo": "bar"}));
        let result = convert_legacy_v0_content_block_to_v1(&block);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("type").unwrap(), "image");
        assert_eq!(result.get("foo").unwrap(), "bar");
        assert!(!result.contains_key("source_type"));
    }
}