agent_chain_core/language_models/
utils.rs1use std::collections::HashMap;
8
9use regex::Regex;
10use serde::{Deserialize, Serialize};
11
/// Restricts [`is_openai_data_block`] to a single kind of OpenAI data block.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DataBlockFilter {
    /// Only `"image_url"` blocks may match.
    Image,
    /// Only `"input_audio"` blocks may match.
    Audio,
    /// Only `"file"` blocks may match.
    File,
}
22
23pub fn is_openai_data_block(block: &serde_json::Value, filter: Option<DataBlockFilter>) -> bool {
36 let block_type = block.get("type").and_then(|t| t.as_str());
37
38 match block_type {
39 Some("image_url") => {
40 if let Some(f) = filter
41 && f != DataBlockFilter::Image
42 {
43 return false;
44 }
45
46 if let Some(image_url) = block.get("image_url")
48 && let Some(obj) = image_url.as_object()
49 {
50 return obj.get("url").and_then(|u| u.as_str()).is_some();
51 }
52 false
53 }
54 Some("input_audio") => {
55 if let Some(f) = filter
56 && f != DataBlockFilter::Audio
57 {
58 return false;
59 }
60
61 if let Some(audio) = block.get("input_audio")
63 && let Some(obj) = audio.as_object()
64 {
65 let has_data = obj.get("data").and_then(|d| d.as_str()).is_some();
66 let has_format = obj.get("format").and_then(|f| f.as_str()).is_some();
67 return has_data && has_format;
68 }
69 false
70 }
71 Some("file") => {
72 if let Some(f) = filter
73 && f != DataBlockFilter::File
74 {
75 return false;
76 }
77
78 if let Some(file) = block.get("file")
80 && let Some(obj) = file.as_object()
81 {
82 let has_file_data = obj.get("file_data").and_then(|d| d.as_str()).is_some();
83 let has_file_id = obj.get("file_id").and_then(|d| d.as_str()).is_some();
84 return has_file_data || has_file_id;
85 }
86 false
87 }
88 _ => false,
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct ParsedDataUri {
95 pub source_type: String,
97 pub data: String,
99 pub mime_type: String,
101}
102
103pub fn parse_data_uri(uri: &str) -> Option<ParsedDataUri> {
113 let re = Regex::new(r"^data:(?P<mime_type>[^;]+);base64,(?P<data>.+)$").ok()?;
114 let captures = re.captures(uri)?;
115
116 let mime_type = captures.name("mime_type")?.as_str();
117 let data = captures.name("data")?.as_str();
118
119 if mime_type.is_empty() || data.is_empty() {
120 return None;
121 }
122
123 Some(ParsedDataUri {
124 source_type: "base64".to_string(),
125 data: data.to_string(),
126 mime_type: mime_type.to_string(),
127 })
128}
129
/// Returns placeholder token ids: one sequential id per whitespace-separated word.
///
/// This is not a real tokenizer. The ids are simply word positions
/// `0, 1, 2, ...`, so only the length of the result carries information.
pub fn get_token_ids_default(text: &str) -> Vec<u32> {
    let word_count = text.split_whitespace().count();
    (0..word_count).map(|position| position as u32).collect()
}
150
/// Roughly estimates the token count as one token per four characters, rounded up.
///
/// Characters are Unicode scalar values (`char`s), not bytes, so multi-byte
/// text is not over-counted. An empty string yields `0`.
pub fn estimate_token_count(text: &str) -> usize {
    let chars = text.chars().count();
    let partial_token = usize::from(chars % 4 != 0);
    chars / 4 + partial_token
}
168
169pub fn convert_legacy_v0_content_block_to_v1(
174 block: &HashMap<String, serde_json::Value>,
175) -> HashMap<String, serde_json::Value> {
176 let mut result = HashMap::new();
177
178 let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("text");
180 result.insert(
181 "type".to_string(),
182 serde_json::Value::String(block_type.to_string()),
183 );
184
185 let source_type = block.get("source_type").and_then(|t| t.as_str());
187
188 match source_type {
189 Some("base64") => {
190 if let Some(data) = block.get("data") {
191 result.insert("base64".to_string(), data.clone());
192 }
193 if let Some(mime_type) = block.get("mime_type") {
194 result.insert("mime_type".to_string(), mime_type.clone());
195 }
196 }
197 Some("url") => {
198 if let Some(url) = block.get("url") {
199 result.insert("url".to_string(), url.clone());
200 }
201 if let Some(mime_type) = block.get("mime_type") {
202 result.insert("mime_type".to_string(), mime_type.clone());
203 }
204 }
205 Some("id") => {
206 if let Some(id) = block.get("id") {
207 result.insert("file_id".to_string(), id.clone());
208 }
209 }
210 Some("text") => {
211 if let Some(text) = block.get("text") {
212 result.insert("text".to_string(), text.clone());
213 }
214 }
215 _ => {
216 for (key, value) in block {
218 if key != "source_type" {
219 result.insert(key.clone(), value.clone());
220 }
221 }
222 }
223 }
224
225 result
226}
227
228pub fn convert_openai_format_to_data_block(
230 block: &serde_json::Value,
231) -> HashMap<String, serde_json::Value> {
232 let mut result = HashMap::new();
233
234 let block_type = block.get("type").and_then(|t| t.as_str()).unwrap_or("");
235
236 match block_type {
237 "image_url" => {
238 result.insert(
239 "type".to_string(),
240 serde_json::Value::String("image".to_string()),
241 );
242
243 if let Some(image_url) = block.get("image_url").and_then(|i| i.as_object()) {
244 if let Some(url) = image_url.get("url").and_then(|u| u.as_str()) {
245 if let Some(parsed) = parse_data_uri(url) {
247 result.insert("base64".to_string(), serde_json::Value::String(parsed.data));
248 result.insert(
249 "mime_type".to_string(),
250 serde_json::Value::String(parsed.mime_type),
251 );
252 } else {
253 result.insert(
254 "url".to_string(),
255 serde_json::Value::String(url.to_string()),
256 );
257 }
258 }
259 if let Some(detail) = image_url.get("detail") {
260 result.insert("detail".to_string(), detail.clone());
261 }
262 }
263 }
264 "input_audio" => {
265 result.insert(
266 "type".to_string(),
267 serde_json::Value::String("audio".to_string()),
268 );
269
270 if let Some(audio) = block.get("input_audio").and_then(|a| a.as_object()) {
271 if let Some(data) = audio.get("data").and_then(|d| d.as_str()) {
272 result.insert(
273 "base64".to_string(),
274 serde_json::Value::String(data.to_string()),
275 );
276 }
277 if let Some(format) = audio.get("format").and_then(|f| f.as_str()) {
278 let mime_type = match format {
280 "wav" => "audio/wav",
281 "mp3" => "audio/mpeg",
282 _ => format,
283 };
284 result.insert(
285 "mime_type".to_string(),
286 serde_json::Value::String(mime_type.to_string()),
287 );
288 }
289 }
290 }
291 "file" => {
292 result.insert(
293 "type".to_string(),
294 serde_json::Value::String("file".to_string()),
295 );
296
297 if let Some(file) = block.get("file").and_then(|f| f.as_object()) {
298 if let Some(file_data) = file.get("file_data").and_then(|d| d.as_str()) {
299 result.insert(
300 "base64".to_string(),
301 serde_json::Value::String(file_data.to_string()),
302 );
303 }
304 if let Some(file_id) = file.get("file_id").and_then(|d| d.as_str()) {
305 result.insert(
306 "file_id".to_string(),
307 serde_json::Value::String(file_id.to_string()),
308 );
309 }
310 if let Some(filename) = file.get("filename").and_then(|f| f.as_str()) {
311 result.insert(
312 "filename".to_string(),
313 serde_json::Value::String(filename.to_string()),
314 );
315 }
316 }
317 }
318 _ => {
319 if let Some(obj) = block.as_object() {
321 for (key, value) in obj {
322 result.insert(key.clone(), value.clone());
323 }
324 }
325 }
326 }
327
328 result
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks that `block` matches with no filter and under `accepted`, but not under `rejected`.
    fn check_filters(
        block: &serde_json::Value,
        accepted: DataBlockFilter,
        rejected: DataBlockFilter,
    ) {
        assert!(
            is_openai_data_block(block, None),
            "block should match without a filter"
        );
        assert!(
            is_openai_data_block(block, Some(accepted)),
            "block should match {accepted:?}"
        );
        assert!(
            !is_openai_data_block(block, Some(rejected)),
            "block should not match {rejected:?}"
        );
    }

    #[test]
    fn test_is_openai_data_block_image() {
        let image = json!({
            "type": "image_url",
            "image_url": { "url": "https://example.com/image.png" }
        });
        check_filters(&image, DataBlockFilter::Image, DataBlockFilter::Audio);
    }

    #[test]
    fn test_is_openai_data_block_audio() {
        let audio = json!({
            "type": "input_audio",
            "input_audio": { "data": "base64data", "format": "wav" }
        });
        check_filters(&audio, DataBlockFilter::Audio, DataBlockFilter::Image);
    }

    #[test]
    fn test_is_openai_data_block_file() {
        let file = json!({
            "type": "file",
            "file": { "file_id": "file-123" }
        });
        check_filters(&file, DataBlockFilter::File, DataBlockFilter::Image);
    }

    #[test]
    fn test_is_openai_data_block_invalid() {
        let text = json!({ "type": "text", "text": "Hello" });
        assert!(!is_openai_data_block(&text, None));
    }

    #[test]
    fn test_parse_data_uri() {
        let parsed = parse_data_uri("data:image/jpeg;base64,/9j/4AAQSkZJRg==")
            .expect("well-formed base64 data URI should parse");

        assert_eq!(
            (
                parsed.source_type.as_str(),
                parsed.mime_type.as_str(),
                parsed.data.as_str()
            ),
            ("base64", "image/jpeg", "/9j/4AAQSkZJRg==")
        );
    }

    #[test]
    fn test_parse_data_uri_invalid() {
        for uri in ["https://example.com/image.png", "data:;base64,"] {
            assert!(parse_data_uri(uri).is_none(), "{uri:?} should not parse");
        }
    }

    #[test]
    fn test_estimate_token_count() {
        let count = estimate_token_count("Hello, world!");
        assert!((1..10).contains(&count), "unexpected estimate {count}");
    }

    #[test]
    fn test_get_token_ids_default() {
        assert_eq!(get_token_ids_default("Hello world test"), [0, 1, 2]);
    }

    #[test]
    fn test_convert_openai_format_to_data_block_image_url() {
        let block = json!({
            "type": "image_url",
            "image_url": {
                "url": "https://example.com/image.png",
                "detail": "high"
            }
        });

        let converted = convert_openai_format_to_data_block(&block);

        for (key, expected) in [
            ("type", "image"),
            ("url", "https://example.com/image.png"),
            ("detail", "high"),
        ] {
            assert_eq!(converted[key], expected, "mismatch for key {key:?}");
        }
    }

    #[test]
    fn test_convert_openai_format_to_data_block_data_uri() {
        let block = json!({
            "type": "image_url",
            "image_url": { "url": "data:image/png;base64,iVBORw0KGgo=" }
        });

        let converted = convert_openai_format_to_data_block(&block);

        for (key, expected) in [
            ("type", "image"),
            ("base64", "iVBORw0KGgo="),
            ("mime_type", "image/png"),
        ] {
            assert_eq!(converted[key], expected, "mismatch for key {key:?}");
        }
    }

    #[test]
    fn test_convert_legacy_v0_content_block_to_v1_base64() {
        let block: HashMap<String, serde_json::Value> = [
            ("type", json!("image")),
            ("source_type", json!("base64")),
            ("data", json!("base64data")),
            ("mime_type", json!("image/png")),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect();

        let converted = convert_legacy_v0_content_block_to_v1(&block);

        assert_eq!(converted["type"], "image");
        assert_eq!(converted["base64"], "base64data");
        assert_eq!(converted["mime_type"], "image/png");
        assert!(!converted.contains_key("source_type"));
    }
}