#![warn(missing_docs, clippy::missing_docs_in_private_items)]
#![feature(str_from_utf16_endian)]
use std::{
fmt::Display,
path::{Path, PathBuf},
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use lazy_static::lazy_static;
use regex::Regex;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use tracing::{debug, trace, warn};
#[cfg(test)]
use {tempfile::tempdir, tracing_test::traced_test};
pub mod database;
pub mod retriever;
pub mod error;
pub mod format;
pub mod llm;
pub mod pdf;
pub mod resource;
use crate::{
database::*,
error::*,
resource::{Author, Paper},
retriever::*,
};
/// Bundled arXiv retriever configuration, embedded at compile time with `include_str!`.
///
/// [`Config::init`] writes this text to `arxiv.toml` in the retrievers directory.
pub const ARXIV_CONFIG: &str = include_str!("../config/retrievers/arxiv.toml");
/// Bundled DOI retriever configuration, embedded at compile time with `include_str!`.
///
/// [`Config::init`] writes this text to `doi.toml` in the retrievers directory.
pub const DOI_CONFIG: &str = include_str!("../config/retrievers/doi.toml");
/// Bundled IACR retriever configuration, embedded at compile time with `include_str!`.
///
/// [`Config::init`] writes this text to `iacr.toml` in the retrievers directory.
pub const IACR_CONFIG: &str = include_str!("../config/retrievers/iacr.toml");
/// Re-exports of commonly used items so they can be brought into scope with a single
/// glob import (`use …::prelude::*;`).
pub mod prelude {
pub use crate::{
database::DatabaseInstruction, error::LearnerError, resource::Resource,
retriever::ResponseProcessor,
};
}
/// Filesystem locations used to set up a [`Learner`].
///
/// Stored as TOML in a `config.toml` file (see [`Config::load`] and [`Config::save`]).
/// Each field has a `#[serde(default = ...)]`, so any field missing from the file is
/// filled in from the corresponding default function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Path of the database file, passed to `Database::open` when a [`Learner`] is built;
/// its parent directory is created if missing. Defaults to `Database::default_path()`.
#[serde(default = "Database::default_path")]
pub database_path: PathBuf,
/// Storage directory, created if missing and handed to `Database::set_storage_path`
/// when a [`Learner`] is built. Defaults to `Database::default_storage_path()`.
#[serde(default = "Database::default_storage_path")]
pub storage_path: PathBuf,
/// Directory holding retriever configuration files; passed to
/// `Retriever::with_config_dir` when a [`Learner`] is built.
/// Defaults to [`Config::default_retrievers_path`].
#[serde(default = "Config::default_retrievers_path")]
pub retrievers_path: PathBuf,
}
/// Top-level handle bundling the active configuration, an open database, and a retriever.
///
/// Construct it with [`Learner::builder`], [`Learner::new`], [`Learner::from_path`],
/// [`Learner::with_config`], or [`Learner::init`].
#[derive(Debug, Clone)]
pub struct Learner {
/// The configuration this learner was built from.
pub config: Config,
/// Database opened at `config.database_path`, with its storage path set to
/// `config.storage_path`.
pub database: Database,
/// Retriever configured from the directory `config.retrievers_path`
/// (via `Retriever::with_config_dir`).
pub retriever: Retriever,
}
/// Step-by-step builder for [`Learner`].
///
/// When [`LearnerBuilder::build`] runs, the configuration is taken from, in order of
/// precedence: an explicit [`Config`] set via `with_config`, then `config.toml` inside the
/// directory set via `with_path`, and finally [`Config::load`] if neither was set.
#[derive(Default)]
pub struct LearnerBuilder {
/// Explicit configuration; takes precedence over `config_path` when both are set.
config: Option<Config>,
/// Directory expected to contain a `config.toml` file.
config_path: Option<PathBuf>,
}
impl Config {
    /// Returns the base configuration directory, `~/.learner`, creating it if needed.
    ///
    /// Falls back to `./.learner` when the home directory cannot be determined.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created.
    pub fn default_path() -> Result<PathBuf> {
        let config_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")).join(".learner");
        std::fs::create_dir_all(&config_dir)?;
        Ok(config_dir)
    }

    /// Returns the default retrievers directory: `retrievers` inside [`Config::default_path`].
    ///
    /// This returns a plain `PathBuf` because it is used as a serde field default. If the base
    /// directory cannot be created, it falls back to `./retrievers`.
    pub fn default_retrievers_path() -> PathBuf {
        Self::default_path().unwrap_or_else(|_| PathBuf::from(".")).join("retrievers")
    }

    /// Loads the configuration from `config.toml` in [`Config::default_path`].
    ///
    /// If that file does not exist yet, a default configuration is created, saved with
    /// [`Config::save`], and returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the base directory cannot be created, or if the file exists but
    /// cannot be read. Also errors if the file cannot be parsed (the message names the file),
    /// or if saving the default configuration fails.
    pub fn load() -> Result<Self> {
        let config_file = Self::default_path()?.join("config.toml");
        // Read directly and treat only `NotFound` as "no config yet". A prior `exists()` check
        // races with concurrent deletion and turns other I/O failures (e.g. permission denied)
        // into an attempt to overwrite the file with defaults.
        match std::fs::read_to_string(&config_file) {
            Ok(content) => toml::from_str(&content).map_err(|e| {
                LearnerError::Config(format!("failed to parse {}: {e}", config_file.display()))
            }),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                let config = Self::default();
                config.save()?;
                Ok(config)
            }
            Err(e) => Err(e.into()),
        }
    }

    /// Writes this configuration as pretty-printed TOML to `config.toml` in
    /// [`Config::default_path`], then creates `self.retrievers_path` if it is missing.
    ///
    /// Note: the file is always written to the default location, even if this configuration
    /// was originally read from a different directory.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, or if the file or directory cannot be written.
    pub fn save(&self) -> Result<()> {
        let config_str =
            toml::to_string_pretty(self).map_err(|e| LearnerError::Config(e.to_string()))?;
        let config_file = Self::default_path()?.join("config.toml");
        debug!("Initializing config to: {config_file:?}");
        std::fs::write(config_file, config_str)?;
        std::fs::create_dir_all(&self.retrievers_path)?;
        Ok(())
    }

    /// Creates and saves a default configuration, then installs the bundled retriever
    /// configurations ([`ARXIV_CONFIG`], [`DOI_CONFIG`], [`IACR_CONFIG`]) into its retrievers
    /// directory.
    ///
    /// Existing `config.toml`, `arxiv.toml`, `doi.toml`, and `iacr.toml` files are overwritten.
    ///
    /// # Errors
    ///
    /// Returns an error if saving the configuration or writing any retriever file fails.
    pub fn init() -> Result<Self> {
        let config = Self::default();
        // `save` also creates `retrievers_path`, so the directory exists for the writes below.
        config.save()?;
        let bundled =
            [("arxiv.toml", ARXIV_CONFIG), ("doi.toml", DOI_CONFIG), ("iacr.toml", IACR_CONFIG)];
        for (file_name, contents) in bundled {
            std::fs::write(config.retrievers_path.join(file_name), contents)?;
        }
        Ok(config)
    }

    /// Returns this configuration with `database_path` replaced.
    pub fn with_database_path(mut self, database_path: &Path) -> Self {
        self.database_path = database_path.to_path_buf();
        self
    }

    /// Returns this configuration with `retrievers_path` replaced.
    pub fn with_retrievers_path(mut self, retrievers_path: &Path) -> Self {
        self.retrievers_path = retrievers_path.to_path_buf();
        self
    }

    /// Returns this configuration with `storage_path` replaced.
    pub fn with_storage_path(mut self, storage_path: &Path) -> Self {
        self.storage_path = storage_path.to_path_buf();
        self
    }
}
impl Default for Config {
fn default() -> Self {
Self {
database_path: Database::default_path(),
storage_path: Database::default_storage_path(),
retrievers_path: Self::default_retrievers_path(),
}
}
}
impl LearnerBuilder {
    /// Creates a builder with no configuration source set.
    pub fn new() -> Self { Self::default() }

    /// Supplies an explicit configuration; it takes precedence over any directory set with
    /// [`LearnerBuilder::with_path`].
    pub fn with_config(self, config: Config) -> Self { Self { config: Some(config), ..self } }

    /// Points the builder at a directory that contains a `config.toml` file.
    pub fn with_path(self, path: impl AsRef<Path>) -> Self {
        Self { config_path: Some(path.as_ref().to_path_buf()), ..self }
    }

    /// Builds the [`Learner`].
    ///
    /// The configuration comes from, in order of precedence: the value set with
    /// [`LearnerBuilder::with_config`], then `config.toml` in the directory set with
    /// [`LearnerBuilder::with_path`], and finally [`Config::load`].
    ///
    /// Any missing directories are created first: the retrievers directory, the database file's
    /// parent, and the storage directory. The database is then opened and its storage path
    /// recorded. Finally, a retriever is configured from the retrievers directory.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration cannot be read or parsed, a directory cannot be
    /// created, or opening/configuring the database or retriever fails.
    pub async fn build(self) -> Result<Learner> {
        let config = self.resolve_config()?;
        Self::prepare_directories(&config)?;
        let database = Database::open(&config.database_path).await?;
        database.set_storage_path(&config.storage_path).await?;
        let retriever = Retriever::new().with_config_dir(&config.retrievers_path)?;
        Ok(Learner { config, database, retriever })
    }

    /// Picks the configuration source using the precedence documented on `build`.
    fn resolve_config(self) -> Result<Config> {
        match (self.config, self.config_path) {
            (Some(explicit), _) => Ok(explicit),
            (None, Some(dir)) => {
                let raw = std::fs::read_to_string(dir.join("config.toml"))?;
                toml::from_str(&raw).map_err(|e| LearnerError::Config(e.to_string()))
            }
            (None, None) => Config::load(),
        }
    }

    /// Creates the retrievers directory, the database file's parent, and the storage directory.
    fn prepare_directories(config: &Config) -> Result<()> {
        std::fs::create_dir_all(&config.retrievers_path)?;
        if let Some(db_parent) = config.database_path.parent() {
            std::fs::create_dir_all(db_parent)?;
        }
        std::fs::create_dir_all(&config.storage_path)?;
        Ok(())
    }
}
impl Learner {
    /// Returns an empty [`LearnerBuilder`] for step-by-step construction.
    pub fn builder() -> LearnerBuilder { LearnerBuilder::new() }

    /// Builds a learner from the configuration returned by [`Config::load`].
    pub async fn new() -> Result<Self> {
        let builder = Self::builder();
        builder.build().await
    }

    /// Builds a learner from `config.toml` in the directory `path`.
    pub async fn from_path(path: impl AsRef<Path>) -> Result<Self> {
        let builder = Self::builder().with_path(path);
        builder.build().await
    }

    /// Builds a learner from an explicit configuration.
    pub async fn with_config(config: Config) -> Result<Self> {
        let builder = Self::builder().with_config(config);
        builder.build().await
    }

    /// Writes a fresh default configuration plus the bundled retriever configs (via
    /// [`Config::init`]), then builds a learner from that configuration.
    pub async fn init() -> Result<Self> {
        let fresh = Config::init()?;
        Self::with_config(fresh).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit `Config` takes precedence over `with_path`. `config_dir` holds no
    /// `config.toml`, so the build would fail if the path were consulted instead.
    #[tokio::test]
    async fn test_learner_creation() {
        let config_dir = tempdir().unwrap();
        let database_dir = tempdir().unwrap();
        let storage_dir = tempdir().unwrap();
        let config = Config::default()
            .with_database_path(&database_dir.path().join("learner.db"))
            .with_retrievers_path(&config_dir.path().join("config/retrievers/"))
            .with_storage_path(storage_dir.path());
        let learner =
            Learner::builder().with_path(config_dir.path()).with_config(config).build().await.unwrap();
        assert_eq!(learner.config.retrievers_path, config_dir.path().join("config/retrievers/"));
        assert_eq!(learner.config.database_path, database_dir.path().join("learner.db"));
        assert_eq!(learner.database.get_storage_path().await.unwrap(), storage_dir.path());
    }

    /// `Learner::from_path` reads `config.toml` from the given directory and honors its paths.
    #[tokio::test]
    async fn test_learner_from_path() {
        let config_dir = tempdir().unwrap();
        let database_dir = tempdir().unwrap();
        let storage_dir = tempdir().unwrap();
        let config = Config::default()
            .with_database_path(&database_dir.path().join("learner.db"))
            .with_retrievers_path(&config_dir.path().join("retrievers"))
            .with_storage_path(storage_dir.path());
        let serialized = toml::to_string_pretty(&config).unwrap();
        std::fs::write(config_dir.path().join("config.toml"), serialized).unwrap();

        let learner = Learner::from_path(config_dir.path()).await.unwrap();

        assert_eq!(learner.config.database_path, database_dir.path().join("learner.db"));
        assert_eq!(learner.config.retrievers_path, config_dir.path().join("retrievers"));
        assert_eq!(learner.database.get_storage_path().await.unwrap(), storage_dir.path());
    }
}