use std::collections::HashMap;
use std::path::PathBuf;
use async_trait::async_trait;
use cognis_core::document_loaders::BaseLoader;
use cognis_core::document_loaders::DocumentStream;
use cognis_core::documents::Document;
use cognis_core::error::{CognisError, Result};
use futures::stream;
use serde_json::Value;
/// Origin of the raw TOML text consumed by the loader.
#[derive(Debug, Clone)]
enum TomlSource {
/// TOML stored in a file at this path; the file is read asynchronously when
/// the loader runs, and the path's display form is used as the source label.
File(PathBuf),
/// TOML text held in memory; it is cloned when read and labelled `<string>`.
String(String),
}
/// Loads a TOML document (from a file or an in-memory string) as one or more
/// [`Document`]s.
///
/// Construct it with `from_file` or `from_string`, then optionally refine the
/// behaviour with `with_content_key` and `with_section_mode`.
///
/// `Debug` and `Clone` are derived so the loader can be logged and reused like
/// any other plain configuration value.
#[derive(Debug, Clone)]
pub struct TomlDocumentLoader {
    /// Where the TOML text comes from.
    source: TomlSource,
    /// When set, the value stored under this key becomes the document content
    /// and the remaining keys of the same table are copied into the metadata.
    /// When unset, the whole value is serialised as the content.
    content_key: Option<String>,
    /// When `true`, every top-level entry of the TOML root table becomes its
    /// own document; when `false`, the whole file becomes a single document.
    section_mode: bool,
}
impl TomlDocumentLoader {
    /// Creates a loader that reads TOML from the file at `path`.
    ///
    /// Nothing is read here; the file is opened when the loader actually runs.
    /// The content key starts unset and section mode starts disabled.
    pub fn from_file(path: impl Into<PathBuf>) -> Self {
        Self {
            source: TomlSource::File(path.into()),
            content_key: None,
            section_mode: false,
        }
    }

    /// Creates a loader that parses the given in-memory TOML text.
    ///
    /// The content key starts unset and section mode starts disabled.
    pub fn from_string(content: impl Into<String>) -> Self {
        Self {
            source: TomlSource::String(content.into()),
            content_key: None,
            section_mode: false,
        }
    }

    /// Selects the key whose value becomes the document content.
    ///
    /// With a content key set, the other keys of the same table are moved into
    /// the document metadata instead of the content.
    pub fn with_content_key(mut self, key: impl Into<String>) -> Self {
        self.content_key = Some(key.into());
        self
    }

    /// Enables or disables section mode (one document per top-level entry).
    pub fn with_section_mode(mut self, enabled: bool) -> Self {
        self.section_mode = enabled;
        self
    }

    /// Returns the raw TOML text for the configured source.
    ///
    /// A file source is read in full with `tokio::fs::read_to_string`, and any
    /// I/O error is converted into the crate error type via `Into`. A string
    /// source simply yields a copy of the stored text.
    async fn read_content(&self) -> Result<String> {
        match &self.source {
            TomlSource::File(path) => tokio::fs::read_to_string(path).await.map_err(Into::into),
            TomlSource::String(s) => Ok(s.clone()),
        }
    }

    /// Returns the value stored under the `source` metadata key: the displayed
    /// path for a file source, or the literal `<string>` for in-memory text.
    fn source_label(&self) -> String {
        match &self.source {
            TomlSource::File(path) => path.display().to_string(),
            TomlSource::String(_) => "<string>".to_string(),
        }
    }

    /// Recursively converts a TOML value into the equivalent JSON value.
    ///
    /// Datetimes are rendered through their `to_string` form. Floats that
    /// `serde_json::Number::from_f64` rejects (non-finite values) become
    /// `Value::Null`, because JSON numbers cannot represent them.
    fn toml_to_json(toml_val: &toml::Value) -> Value {
        match toml_val {
            toml::Value::String(s) => Value::String(s.clone()),
            toml::Value::Integer(i) => Value::Number((*i).into()),
            toml::Value::Float(f) => serde_json::Number::from_f64(*f)
                .map(Value::Number)
                .unwrap_or(Value::Null),
            toml::Value::Boolean(b) => Value::Bool(*b),
            toml::Value::Datetime(dt) => Value::String(dt.to_string()),
            toml::Value::Array(arr) => Value::Array(arr.iter().map(Self::toml_to_json).collect()),
            toml::Value::Table(table) => Value::Object(
                table
                    .iter()
                    .map(|(k, v)| (k.clone(), Self::toml_to_json(v)))
                    .collect(),
            ),
        }
    }

    /// Builds a [`Document`] from one TOML value.
    ///
    /// * `toml_val` - the value (whole file or a single section) to convert.
    /// * `source` - label stored under the `source` metadata key.
    /// * `section_name` - when present, stored under the `section` metadata key.
    ///
    /// Without a content key the content is the value itself: a bare string is
    /// used verbatim, anything else is its JSON text.
    ///
    /// With a content key the content is the value found under that key (a
    /// string verbatim, anything else as JSON text, or an empty string when the
    /// key is absent or the value is not a table), and every other key of the
    /// table is copied into the metadata.
    ///
    /// The loader-owned `source` and `section` entries always take precedence
    /// over same-named keys coming from the TOML table.
    fn value_to_document(
        &self,
        toml_val: &toml::Value,
        source: &str,
        section_name: Option<&str>,
    ) -> Document {
        let json_val = Self::toml_to_json(toml_val);
        let mut metadata = HashMap::new();

        let content = match &self.content_key {
            Some(key) => {
                let content = json_val
                    .get(key)
                    .map(|v| match v {
                        Value::String(s) => s.clone(),
                        other => other.to_string(),
                    })
                    .unwrap_or_default();
                // Consume the converted map so sibling values are moved into
                // the metadata rather than cloned.
                if let Value::Object(map) = json_val {
                    for (k, v) in map {
                        if k != *key {
                            metadata.insert(k, v);
                        }
                    }
                }
                content
            }
            None => match json_val {
                Value::String(s) => s,
                other => other.to_string(),
            },
        };

        // Inserted last on purpose: a TOML key that happens to be called
        // `source` or `section` must not overwrite the provenance metadata
        // that consumers of the document rely on.
        metadata.insert("source".to_string(), Value::String(source.to_string()));
        if let Some(name) = section_name {
            metadata.insert("section".to_string(), Value::String(name.to_string()));
        }

        Document::new(content).with_metadata(metadata)
    }
}
#[async_trait]
impl BaseLoader for TomlDocumentLoader {
async fn lazy_load(&self) -> Result<DocumentStream> {
let raw = self.read_content().await?;
let source = self.source_label();
let root: toml::Value = raw.parse().map_err(|e: toml::de::Error| {
CognisError::Other(format!("Failed to parse TOML: {}", e))
})?;
let docs: Vec<Result<Document>> = if self.section_mode {
match &root {
toml::Value::Table(table) => table
.iter()
.map(|(key, val)| Ok(self.value_to_document(val, &source, Some(key))))
.collect(),
_ => {
vec![Ok(self.value_to_document(&root, &source, None))]
}
}
} else {
vec![Ok(self.value_to_document(&root, &source, None))]
};
Ok(Box::pin(stream::iter(docs)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Reads a metadata entry as a string slice, if present and a string.
    fn meta_str<'a>(doc: &'a Document, key: &str) -> Option<&'a str> {
        doc.metadata.get(key).and_then(Value::as_str)
    }

    /// A file-backed loader produces a single document holding the file's values.
    #[tokio::test]
    async fn test_toml_basic_load() {
        let toml_text = r#"
title = "My Config"
version = "1.0"
"#;
        let mut file = NamedTempFile::with_suffix(".toml").unwrap();
        write!(file, "{}", toml_text).unwrap();

        let loader = TomlDocumentLoader::from_file(file.path());
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 1);
        assert!(loaded[0].page_content.contains("My Config"));
    }

    /// Section mode yields one document per top-level table, each tagged with its name.
    #[tokio::test]
    async fn test_toml_section_mode() {
        let toml_text = r#"
[package]
name = "myapp"
version = "0.1.0"
[dependencies]
serde = "1.0"
tokio = "1.0"
[dev-dependencies]
tempfile = "3"
"#;
        let loader = TomlDocumentLoader::from_string(toml_text).with_section_mode(true);
        let loaded = loader.load().await.unwrap();
        assert_eq!(loaded.len(), 3);

        let names: Vec<&str> = loaded
            .iter()
            .map(|doc| meta_str(doc, "section").unwrap())
            .collect();
        for expected in ["package", "dependencies", "dev-dependencies"] {
            assert!(names.contains(&expected));
        }
    }

    /// The content key selects the content; sibling keys land in the metadata.
    #[tokio::test]
    async fn test_toml_content_key() {
        let toml_text = r#"
name = "test-project"
description = "A test project"
version = "1.0.0"
"#;
        let loader = TomlDocumentLoader::from_string(toml_text).with_content_key("description");
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 1);
        let doc = &loaded[0];
        assert_eq!(doc.page_content, "A test project");
        assert_eq!(meta_str(doc, "name"), Some("test-project"));
    }

    /// In-memory input is labelled `<string>` in the `source` metadata.
    #[tokio::test]
    async fn test_toml_from_string() {
        let toml_text = "key = \"value\"\n";
        let loader = TomlDocumentLoader::from_string(toml_text);
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(meta_str(&loaded[0], "source"), Some("<string>"));
    }

    /// A nested table stays inside its parent section's document.
    #[tokio::test]
    async fn test_toml_nested_tables() {
        let toml_text = r#"
[server]
host = "localhost"
port = 8080
[server.tls]
enabled = true
cert = "/path/to/cert"
"#;
        let loader = TomlDocumentLoader::from_string(toml_text).with_section_mode(true);
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 1);
        let body = &loaded[0].page_content;
        for needle in ["localhost", "8080", "tls"] {
            assert!(body.contains(needle));
        }
    }

    /// Empty input still produces one document whose content is an empty JSON object.
    #[tokio::test]
    async fn test_toml_empty_file() {
        let toml_text = "";
        let loader = TomlDocumentLoader::from_string(toml_text);
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].page_content, "{}");
    }

    /// Section mode combined with a content key fills both content and metadata.
    #[tokio::test]
    async fn test_toml_metadata_extraction() {
        let toml_text = r#"
[package]
name = "example"
version = "2.0"
edition = "2021"
"#;
        let loader = TomlDocumentLoader::from_string(toml_text)
            .with_section_mode(true)
            .with_content_key("name");
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 1);
        let doc = &loaded[0];
        assert_eq!(doc.page_content, "example");
        assert_eq!(meta_str(doc, "version"), Some("2.0"));
        assert_eq!(meta_str(doc, "edition"), Some("2021"));
        assert_eq!(meta_str(doc, "section"), Some("package"));
    }

    /// Top-level scalars count as sections alongside tables.
    #[tokio::test]
    async fn test_toml_section_mode_with_scalars() {
        let toml_text = r#"
title = "root value"
[settings]
debug = true
"#;
        let loader = TomlDocumentLoader::from_string(toml_text).with_section_mode(true);
        let loaded = loader.load().await.unwrap();

        assert_eq!(loaded.len(), 2);
    }
}