1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
use std::collections::HashMap;

use async_trait::async_trait;
use serde_json::Value;

use crate::schemas::Document;

use super::TextSplitterError;

/// Splits text into chunks and wraps those chunks in [`Document`]s.
///
/// Implementors only have to provide [`TextSplitter::split_text`]; the
/// document-oriented helpers are default methods built on top of it.
#[async_trait]
pub trait TextSplitter: Send + Sync {
    /// Splits `text` into a list of chunks.
    ///
    /// How chunk boundaries are chosen is entirely up to the implementor.
    ///
    /// # Errors
    ///
    /// Returns a [`TextSplitterError`] when the implementor fails to split
    /// the text.
    async fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError>;

    /// Splits every document in `documents` and returns one [`Document`] per
    /// resulting chunk.
    ///
    /// Each input document's `page_content` and `metadata` are copied and
    /// forwarded to [`TextSplitter::create_documents`], so every chunk carries
    /// a copy of the metadata of the document it came from.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`TextSplitter::create_documents`].
    async fn split_documents(
        &self,
        documents: &[Document],
    ) -> Result<Vec<Document>, TextSplitterError> {
        // `create_documents` takes two parallel slices, so the documents are
        // split into their text and metadata columns in a single pass.
        let (texts, metadatas): (Vec<String>, Vec<HashMap<String, Value>>) = documents
            .iter()
            .map(|doc| (doc.page_content.clone(), doc.metadata.clone()))
            .unzip();

        self.create_documents(&texts, &metadatas).await
    }

    /// Splits each entry of `text` and builds one [`Document`] per chunk,
    /// attaching a copy of the metadata found at the same index.
    ///
    /// `metadatas` may be empty, in which case every produced document gets
    /// an empty metadata map. Otherwise it must contain exactly one entry per
    /// element of `text`.
    ///
    /// Chunks are returned in input order: all chunks of `text[0]`, then all
    /// chunks of `text[1]`, and so on.
    ///
    /// # Errors
    ///
    /// * [`TextSplitterError::MetadataTextMismatch`] if `metadatas` is
    ///   non-empty and its length differs from `text.len()`. This is checked
    ///   before any text is split.
    /// * Any error returned by [`TextSplitter::split_text`]; processing stops
    ///   at the first failing entry.
    async fn create_documents(
        &self,
        text: &[String],
        metadatas: &[HashMap<String, Value>],
    ) -> Result<Vec<Document>, TextSplitterError> {
        // An empty slice means "no metadata supplied"; any other length has
        // to line up one-to-one with `text`.
        if !metadatas.is_empty() && metadatas.len() != text.len() {
            return Err(TextSplitterError::MetadataTextMismatch);
        }

        let mut documents: Vec<Document> = Vec::new();
        for (index, content) in text.iter().enumerate() {
            let chunks = self.split_text(content).await?;
            // `None` only when `metadatas` is empty (lengths were validated
            // above); those documents fall back to an empty map. Borrowing
            // here avoids deep-copying the whole metadata slice up front.
            let metadata = metadatas.get(index);
            documents.extend(chunks.into_iter().map(|chunk| {
                Document::new(chunk).with_metadata(metadata.cloned().unwrap_or_default())
            }));
        }

        Ok(documents)
    }
}