1use std::collections::HashMap;
14use std::fmt;
15use std::fs;
16use std::io::{self, Read};
17use std::path::{Path, PathBuf};
18
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21
22#[cfg(feature = "specta")]
23use specta::Type;
24
25use crate::load::Serializable;
26
27#[cfg_attr(feature = "specta", derive(Type))]
34#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
35pub struct BaseMedia {
36 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub id: Option<String>,
42
43 #[serde(default)]
45 pub metadata: HashMap<String, Value>,
46}
47
48impl BaseMedia {
49 pub fn new(id: Option<String>, metadata: HashMap<String, Value>) -> Self {
51 Self { id, metadata }
52 }
53
54 pub fn with_id(mut self, id: impl Into<String>) -> Self {
56 self.id = Some(id.into());
57 self
58 }
59
60 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
62 self.metadata = metadata;
63 self
64 }
65}
66
67#[cfg_attr(feature = "specta", derive(Type))]
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
105pub struct Blob {
106 #[serde(default, skip_serializing_if = "Option::is_none")]
108 pub id: Option<String>,
109
110 #[serde(default)]
112 pub metadata: HashMap<String, Value>,
113
114 #[serde(default, skip_serializing_if = "Option::is_none")]
116 pub data: Option<BlobData>,
117
118 #[serde(default, skip_serializing_if = "Option::is_none")]
120 pub mimetype: Option<String>,
121
122 #[serde(default = "default_encoding")]
125 pub encoding: String,
126
127 #[serde(default, skip_serializing_if = "Option::is_none")]
129 pub path: Option<PathBuf>,
130}
131
/// Serde default for `Blob::encoding`: the string `"utf-8"`.
fn default_encoding() -> String {
    String::from("utf-8")
}
135
136#[cfg_attr(feature = "specta", derive(Type))]
138#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
139#[serde(untagged)]
140pub enum BlobData {
141 Text(String),
143 #[serde(with = "serde_bytes_base64")]
145 Bytes(Vec<u8>),
146}
147
148mod serde_bytes_base64 {
149 use base64::{Engine as _, engine::general_purpose::STANDARD};
150 use serde::{self, Deserialize, Deserializer, Serializer};
151
152 pub fn serialize<S>(bytes: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
153 where
154 S: Serializer,
155 {
156 let s = STANDARD.encode(bytes);
157 serializer.serialize_str(&s)
158 }
159
160 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
161 where
162 D: Deserializer<'de>,
163 {
164 let s = String::deserialize(deserializer)?;
165 STANDARD.decode(&s).map_err(serde::de::Error::custom)
166 }
167}
168
169impl Blob {
170 pub fn builder() -> BlobBuilder {
172 BlobBuilder::default()
173 }
174
175 pub fn from_data(data: impl Into<String>) -> Self {
177 Self {
178 id: None,
179 metadata: HashMap::new(),
180 data: Some(BlobData::Text(data.into())),
181 mimetype: None,
182 encoding: "utf-8".to_string(),
183 path: None,
184 }
185 }
186
187 pub fn from_bytes(data: Vec<u8>) -> Self {
189 Self {
190 id: None,
191 metadata: HashMap::new(),
192 data: Some(BlobData::Bytes(data)),
193 mimetype: None,
194 encoding: "utf-8".to_string(),
195 path: None,
196 }
197 }
198
199 pub fn from_path(
211 path: impl AsRef<Path>,
212 mime_type: Option<String>,
213 encoding: Option<String>,
214 metadata: Option<HashMap<String, Value>>,
215 ) -> Self {
216 let path = path.as_ref();
217 let mimetype = mime_type.or_else(|| guess_mime_type(path));
218
219 Self {
220 id: None,
221 metadata: metadata.unwrap_or_default(),
222 data: None,
223 mimetype,
224 encoding: encoding.unwrap_or_else(|| "utf-8".to_string()),
225 path: Some(path.to_path_buf()),
226 }
227 }
228
229 pub fn source(&self) -> Option<String> {
235 if let Some(Value::String(source)) = self.metadata.get("source") {
236 return Some(source.clone());
237 }
238 self.path.as_ref().map(|p| p.to_string_lossy().to_string())
239 }
240
241 pub fn as_string(&self) -> io::Result<String> {
247 match &self.data {
248 Some(BlobData::Text(s)) => Ok(s.clone()),
249 Some(BlobData::Bytes(b)) => String::from_utf8(b.clone())
250 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)),
251 None => {
252 if let Some(path) = &self.path {
253 fs::read_to_string(path)
254 } else {
255 Err(io::Error::new(
256 io::ErrorKind::InvalidData,
257 format!("Unable to get string for blob {:?}", self),
258 ))
259 }
260 }
261 }
262 }
263
264 pub fn as_bytes(&self) -> io::Result<Vec<u8>> {
270 match &self.data {
271 Some(BlobData::Bytes(b)) => Ok(b.clone()),
272 Some(BlobData::Text(s)) => Ok(s.as_bytes().to_vec()),
273 None => {
274 if let Some(path) = &self.path {
275 fs::read(path)
276 } else {
277 Err(io::Error::new(
278 io::ErrorKind::InvalidData,
279 format!("Unable to get bytes for blob {:?}", self),
280 ))
281 }
282 }
283 }
284 }
285
286 pub fn as_bytes_io(&self) -> io::Result<Box<dyn Read>> {
292 match &self.data {
293 Some(BlobData::Bytes(b)) => Ok(Box::new(std::io::Cursor::new(b.clone()))),
294 Some(BlobData::Text(s)) => Ok(Box::new(std::io::Cursor::new(s.as_bytes().to_vec()))),
295 None => {
296 if let Some(path) = &self.path {
297 let file = fs::File::open(path)?;
298 Ok(Box::new(std::io::BufReader::new(file)))
299 } else {
300 Err(io::Error::other(format!(
301 "Unable to convert blob {:?}",
302 self
303 )))
304 }
305 }
306 }
307 }
308}
309
310impl fmt::Display for Blob {
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 write!(f, "Blob {:p}", self)?;
313 if let Some(source) = self.source() {
314 write!(f, " {}", source)?;
315 }
316 Ok(())
317 }
318}
319
320#[derive(Debug, Default)]
322pub struct BlobBuilder {
323 id: Option<String>,
324 metadata: HashMap<String, Value>,
325 data: Option<BlobData>,
326 mimetype: Option<String>,
327 encoding: String,
328 path: Option<PathBuf>,
329}
330
331impl BlobBuilder {
332 pub fn id(mut self, id: impl Into<String>) -> Self {
334 self.id = Some(id.into());
335 self
336 }
337
338 pub fn metadata(mut self, metadata: HashMap<String, Value>) -> Self {
340 self.metadata = metadata;
341 self
342 }
343
344 pub fn data(mut self, data: impl Into<String>) -> Self {
346 self.data = Some(BlobData::Text(data.into()));
347 self
348 }
349
350 pub fn bytes(mut self, data: Vec<u8>) -> Self {
352 self.data = Some(BlobData::Bytes(data));
353 self
354 }
355
356 pub fn mime_type(mut self, mime_type: impl Into<String>) -> Self {
358 self.mimetype = Some(mime_type.into());
359 self
360 }
361
362 pub fn encoding(mut self, encoding: impl Into<String>) -> Self {
364 self.encoding = encoding.into();
365 self
366 }
367
368 pub fn path(mut self, path: impl AsRef<Path>) -> Self {
370 self.path = Some(path.as_ref().to_path_buf());
371 self
372 }
373
374 pub fn build(self) -> Result<Blob, &'static str> {
380 if self.data.is_none() && self.path.is_none() {
381 return Err("Either data or path must be provided");
382 }
383
384 Ok(Blob {
385 id: self.id,
386 metadata: self.metadata,
387 data: self.data,
388 mimetype: self.mimetype,
389 encoding: if self.encoding.is_empty() {
390 "utf-8".to_string()
391 } else {
392 self.encoding
393 },
394 path: self.path,
395 })
396 }
397}
398
/// Guesses a MIME type from the file extension of `path`.
///
/// The extension is compared case-insensitively against a fixed table.
/// Returns `None` when the path has no extension or the extension is not in
/// the table.
fn guess_mime_type(path: &Path) -> Option<String> {
    let extension = path.extension()?.to_string_lossy().to_lowercase();

    let mime = match extension.as_str() {
        // Text formats.
        "txt" => "text/plain",
        "html" | "htm" => "text/html",
        "css" => "text/css",
        "csv" => "text/csv",
        "md" => "text/markdown",
        "rs" => "text/x-rust",
        "py" => "text/x-python",
        // Structured data and application formats.
        "js" => "application/javascript",
        "json" => "application/json",
        "xml" => "application/xml",
        "pdf" => "application/pdf",
        "yaml" | "yml" => "application/x-yaml",
        "toml" => "application/toml",
        // Archives.
        "zip" => "application/zip",
        "gz" | "gzip" => "application/gzip",
        "tar" => "application/x-tar",
        // Images.
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "svg" => "image/svg+xml",
        // Audio and video.
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "mp4" => "video/mp4",
        "webm" => "video/webm",
        _ => return None,
    };

    Some(mime.to_owned())
}
432
433#[cfg_attr(feature = "specta", derive(Type))]
450#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
451pub struct Document {
452 pub page_content: String,
454
455 #[serde(default, skip_serializing_if = "Option::is_none")]
460 pub id: Option<String>,
461
462 #[serde(default)]
464 pub metadata: HashMap<String, Value>,
465
466 #[serde(rename = "type", default = "document_type_default")]
468 pub type_: String,
469}
470
/// Serde default for `Document::type_`: the string `"Document"`.
fn document_type_default() -> String {
    String::from("Document")
}
474
475impl Document {
476 pub fn new(page_content: impl Into<String>) -> Self {
478 Self {
479 page_content: page_content.into(),
480 id: None,
481 metadata: HashMap::new(),
482 type_: "Document".to_string(),
483 }
484 }
485
486 pub fn with_id(mut self, id: impl Into<String>) -> Self {
488 self.id = Some(id.into());
489 self
490 }
491
492 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
494 self.metadata = metadata;
495 self
496 }
497
498 pub fn with_metadata_entry(mut self, key: impl Into<String>, value: Value) -> Self {
500 self.metadata.insert(key.into(), value);
501 self
502 }
503
504 pub fn content(&self) -> &str {
506 &self.page_content
507 }
508
509 pub fn is_empty(&self) -> bool {
511 self.page_content.is_empty()
512 }
513}
514
515impl fmt::Display for Document {
516 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
517 if self.metadata.is_empty() {
518 write!(f, "page_content='{}'", self.page_content)
519 } else {
520 write!(
521 f,
522 "page_content='{}' metadata={:?}",
523 self.page_content, self.metadata
524 )
525 }
526 }
527}
528
529impl From<&str> for Document {
530 fn from(s: &str) -> Self {
531 Document::new(s)
532 }
533}
534
535impl From<String> for Document {
536 fn from(s: String) -> Self {
537 Document::new(s)
538 }
539}
540
541impl Serializable for Document {
542 fn is_lc_serializable() -> bool
543 where
544 Self: Sized,
545 {
546 true
547 }
548
549 fn get_lc_namespace() -> Vec<String>
550 where
551 Self: Sized,
552 {
553 vec![
554 "langchain".to_string(),
555 "schema".to_string(),
556 "document".to_string(),
557 ]
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-entry `{"source": <value>}` metadata map shared by
    /// several tests.
    fn source_metadata(value: &str) -> HashMap<String, Value> {
        let mut metadata = HashMap::new();
        metadata.insert(String::from("source"), Value::String(String::from(value)));
        metadata
    }

    /// A freshly created document carries only its content and the default
    /// type marker.
    #[test]
    fn test_document_creation() {
        let document = Document::new("Hello, world!");

        assert_eq!(document.page_content, "Hello, world!");
        assert!(document.id.is_none());
        assert!(document.metadata.is_empty());
        assert_eq!(document.type_, "Document");
    }

    /// The `with_*` methods set the id and replace the metadata.
    #[test]
    fn test_document_with_metadata() {
        let document = Document::new("Test content")
            .with_id("doc-123")
            .with_metadata(source_metadata("test.txt"));

        assert_eq!(document.id, Some("doc-123".to_string()));
        let expected_source = Value::String("test.txt".to_string());
        assert_eq!(document.metadata.get("source"), Some(&expected_source));
    }

    /// `Display` shows the metadata section only when metadata is present.
    #[test]
    fn test_document_display() {
        let plain = Document::new("Hello");
        assert_eq!(plain.to_string(), "page_content='Hello'");

        let mut flags = HashMap::new();
        flags.insert("key".to_string(), Value::Bool(true));
        let rendered = Document::new("Hello").with_metadata(flags).to_string();
        assert!(rendered.contains("page_content='Hello'"));
        assert!(rendered.contains("metadata="));
    }

    /// Text blobs are readable both as a string and as bytes.
    #[test]
    fn test_blob_from_data() {
        let text_blob = Blob::from_data("Hello, world!");

        assert_eq!(text_blob.as_string().unwrap(), "Hello, world!");
        assert_eq!(text_blob.as_bytes().unwrap(), b"Hello, world!");
    }

    /// Byte blobs are readable both as bytes and as a string.
    #[test]
    fn test_blob_from_bytes() {
        let byte_blob = Blob::from_bytes(b"Hello, bytes!".to_vec());

        assert_eq!(byte_blob.as_bytes().unwrap(), b"Hello, bytes!");
        assert_eq!(byte_blob.as_string().unwrap(), "Hello, bytes!");
    }

    /// The builder carries every configured field through to the blob.
    #[test]
    fn test_blob_builder() {
        let built = Blob::builder()
            .data("Test data")
            .mime_type("text/plain")
            .encoding("utf-8")
            .build()
            .unwrap();

        assert_eq!(built.as_string().unwrap(), "Test data");
        assert_eq!(built.mimetype, Some("text/plain".to_string()));
        assert_eq!(built.encoding, "utf-8");
    }

    /// Building without data or path is rejected.
    #[test]
    fn test_blob_builder_error() {
        let outcome = Blob::builder().build();
        assert!(outcome.is_err());
    }

    /// `source` falls back to the path and prefers the `"source"` metadata
    /// entry.
    #[test]
    fn test_blob_source() {
        let path_blob = Blob::from_path("/test/path.txt", None, None, None);
        assert_eq!(path_blob.source(), Some("/test/path.txt".to_string()));

        let tagged_blob = Blob::builder()
            .data("test")
            .metadata(source_metadata("custom_source"))
            .build()
            .unwrap();
        assert_eq!(tagged_blob.source(), Some("custom_source".to_string()));
    }

    /// Known extensions map to their MIME type; unknown ones yield `None`.
    #[test]
    fn test_guess_mime_type() {
        let known = [
            ("test.txt", "text/plain"),
            ("test.json", "application/json"),
            ("test.png", "image/png"),
        ];
        for (file_name, mime) in known.iter() {
            assert_eq!(
                guess_mime_type(Path::new(file_name)),
                Some(mime.to_string())
            );
        }

        assert_eq!(guess_mime_type(Path::new("test.unknown")), None);
    }

    /// A document survives a JSON round trip unchanged.
    #[test]
    fn test_document_serialization() {
        let original = Document::new("Test content")
            .with_id("doc-123")
            .with_metadata(source_metadata("test.txt"));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Document = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original, decoded);
    }
}