use async_trait::async_trait;
use glob::Pattern;
use notify::{
event::{CreateKind, ModifyKind},
Config, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher,
};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::broadcast;
use tracing::{debug, error, info, instrument};
use crate::{
document::Document,
loader::{
builtins::file_loaders::utils::{
extract_parent_dir, get_dirs_to_watch, parse_file, resolve_input_to_files,
},
Loader,
},
};
use super::{utils::load_initial, FileLoaderError};
/// Broadcast channel capacity used by the builder when no files match the
/// configured patterns at build time.
const DEFAULT_CHANNEL_CAPACITY: usize = 20;
/// Debounce interval, in milliseconds, used by the file-watching task.
const DEBOUNCE_DURATION_MILLIS: u64 = 500;
/// Builder for [`FileUpdatingLoader`].
///
/// Keeps both the raw glob strings and their compiled form so that `build`
/// can resolve the strings to files while the loader matches paths against
/// the compiled patterns.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct FileUpdatingLoaderBuilder {
// Glob patterns exactly as supplied by the caller.
glob_patterns: Vec<String>,
// The same patterns compiled with `glob::Pattern::new`, in the same order.
evaluated_patterns: Vec<Pattern>,
}
impl FileUpdatingLoaderBuilder {
#[instrument]
pub fn new(glob_patterns: Vec<String>) -> Result<Self, FileLoaderError> {
let evaluated_patterns: Vec<Pattern> = glob_patterns
.iter()
.map(|p| Pattern::new(p))
.collect::<Result<_, _>>()?;
info!(
"Successfully evaluated {} glob patterns",
evaluated_patterns.len()
);
Ok(Self {
glob_patterns,
evaluated_patterns,
})
}
#[instrument]
pub fn build(self) -> FileUpdatingLoader {
let files = resolve_input_to_files(
self.glob_patterns
.iter()
.map(std::string::String::as_str)
.collect(),
)
.unwrap();
let capacity = if files.is_empty() {
DEFAULT_CHANNEL_CAPACITY
} else {
files.len()
};
let (tx, _rx) = broadcast::channel(capacity);
debug!("broadcast channel with capacity: {} created", capacity);
FileUpdatingLoader {
patterns: self.evaluated_patterns,
tx,
sent: AtomicBool::new(false),
}
}
}
/// Loader that broadcasts the documents matching a set of glob patterns and
/// then keeps broadcasting updates as the underlying files change.
#[derive(Debug)]
pub struct FileUpdatingLoader {
// Sending half of the broadcast channel; every subscriber gets a receiver
// created from it.
tx: broadcast::Sender<Document>,
// Set to `true` by the first `subscribe` call, which performs the initial
// load and starts the file watcher.
sent: AtomicBool,
// Compiled glob patterns used for the initial load and to filter events.
patterns: Vec<Pattern>,
}
#[async_trait]
impl Loader for FileUpdatingLoader {
#[instrument(fields(self = format!("FileUpdatingLoader {{sent: {}}}", self.sent.load(Ordering::Acquire))))]
async fn subscribe(&self) -> broadcast::Receiver<Document> {
let receiver = self.tx.subscribe();
if !self.sent.load(Ordering::Acquire)
&& self
.sent
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
let initial_docs = load_initial(&self.patterns);
let mut sent_docs_count = 0;
let total_docs_count = initial_docs.len();
for doc in initial_docs {
if let Err(e) = self.tx.send(doc) {
error!("Loader failed to send document: {} to subscribers", e.0.id);
} else {
sent_docs_count += 1;
}
}
info!(
"Loader sent {} of {} initial documents to subscribers",
sent_docs_count, total_docs_count
);
let to_be_watched = get_dirs_to_watch(
&self
.patterns
.iter()
.map(|p| extract_parent_dir(p.as_str()))
.collect::<Vec<PathBuf>>(),
);
let txc = self.tx.clone();
let pc = self.patterns.clone();
tokio::spawn(async move {
let (evt_tx, evt_rx) = std::sync::mpsc::channel::<notify::Result<notify::Event>>();
let mut watcher = RecommendedWatcher::new(evt_tx, Config::default()).unwrap();
for path in &to_be_watched.clone() {
if let Err(e) = watcher.watch(path, RecursiveMode::Recursive) {
error!(
"Failed to add path {} to watcher: {}",
path.to_str().unwrap_or("unknown"),
e
);
} else {
info!(
"Added path {} to watcher",
path.to_str().unwrap_or("unknown")
);
}
}
let mut last_event_time = Instant::now();
let debounce_duration = Duration::from_millis(DEBOUNCE_DURATION_MILLIS);
loop {
let now = Instant::now();
if now.duration_since(last_event_time) >= debounce_duration {
for event in &evt_rx {
let event = event.unwrap();
let out = process_event(&event, &pc);
if out.is_none() {
continue;
}
let out = out.unwrap();
txc.send(document_for_event(out.0.as_str(), out.1)).unwrap();
last_event_time = now;
}
}
}
});
}
receiver
}
}
/// File-system change kinds this loader reacts to, as classified by
/// `process_event`.
#[derive(Debug)]
enum EventType {
// Data of an existing file changed; the file is re-parsed.
Modify,
// A new file was created; the file is parsed.
Create,
// A path was removed; the resulting document carries empty data.
Delete,
}
/// Builds the [`Document`] to broadcast for a change of `path`.
///
/// The document id is `path` itself. For `Create` and `Modify` the data is the
/// parsed content of the file; for `Delete` it is an empty string.
///
/// # Panics
///
/// Panics if the file cannot be parsed for a `Create` or `Modify` event.
#[instrument]
fn document_for_event(path: &str, et: EventType) -> Document {
    let data = match et {
        // Nothing left on disk to read.
        EventType::Delete => String::new(),
        EventType::Create | EventType::Modify => {
            parse_file(std::path::Path::new(path)).unwrap()
        }
    };
    debug!("Created document for {} with event type {:?}", path, et);
    Document {
        data,
        id: path.to_owned(),
    }
}
/// Classifies a watcher event.
///
/// Only the first path of the event is considered. Returns the path together
/// with the matching [`EventType`], or `None` when the event has no UTF-8
/// path, the path matches none of `patterns`, or the event kind is not a file
/// creation, a data modification or a removal.
// NOTE(review): backends that report `CreateKind::Any` / `ModifyKind::Any`
// would fall into the unhandled branch — confirm against the target platforms.
#[instrument]
fn process_event(event: &Event, patterns: &[Pattern]) -> Option<(String, EventType)> {
    let path = event.paths.first().and_then(|p| p.to_str())?;

    let is_watched = patterns.iter().any(|pattern| pattern.matches(path));
    if !is_watched {
        debug!("Event path {} does not match any patterns", path);
        return None;
    }

    let event_type = match event.kind {
        EventKind::Create(CreateKind::File) => EventType::Create,
        EventKind::Modify(ModifyKind::Data(_)) => EventType::Modify,
        EventKind::Remove(_) => EventType::Delete,
        _ => {
            debug!("Unhandled event type {:?}", event.kind);
            return None;
        }
    };
    Some((path.to_string(), event_type))
}
// Unit tests for event classification and document construction, plus
// integration-style tests that run the loader against a temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use notify::{Event, EventKind};
use std::path::PathBuf;
use tempfile;
// A file-creation event whose path matches the pattern is classified as
// `Create` and returns the path unchanged.
#[test]
fn test_process_event_matching_pattern() {
let pattern = Pattern::new("*.txt").unwrap();
let event = Event {
kind: EventKind::Create(CreateKind::File),
paths: vec![PathBuf::from("test.txt")],
attrs: Default::default(),
};
let result = process_event(&event, &[pattern]);
assert!(result.is_some());
let (path, et) = result.unwrap();
assert_eq!(path, "test.txt");
assert!(matches!(et, EventType::Create));
}
// An event whose path matches no pattern is dropped.
#[test]
fn test_process_event_non_matching_pattern() {
let pattern = Pattern::new("*.md").unwrap();
let event = Event {
kind: EventKind::Create(CreateKind::File),
paths: vec![PathBuf::from("test.txt")],
attrs: Default::default(),
};
let result = process_event(&event, &[pattern]);
assert!(result.is_none());
}
// A `Create` event yields a document whose id is the path and whose data is
// the file content.
#[test]
fn test_document_for_event_create() {
let temp_dir = tempfile::tempdir().unwrap();
let file_path = temp_dir.path().join("test.txt");
std::fs::write(&file_path, "test content").unwrap();
let doc = document_for_event(file_path.to_str().unwrap(), EventType::Create);
assert_eq!(doc.id, file_path.to_str().unwrap());
assert_eq!(doc.data, "test content");
}
// A `Delete` event yields empty data without touching the file system.
#[test]
fn test_document_for_event_delete() {
let doc = document_for_event("test.txt", EventType::Delete);
assert_eq!(doc.data, "");
}
// The first subscriber receives the files that already exist when it
// subscribes.
#[tokio::test]
async fn test_initial_load_sends_documents() {
let temp_dir = tempfile::tempdir().unwrap();
let file_path = temp_dir.path().join("test.txt");
std::fs::write(&file_path, "initial").unwrap();
let builder = FileUpdatingLoaderBuilder::new(vec![temp_dir
.path()
.join("*.txt")
.to_str()
.unwrap()
.to_string()])
.unwrap();
let loader = builder.build();
let mut receiver = loader.subscribe().await;
let doc = receiver.recv().await.unwrap();
assert_eq!(doc.id, file_path.to_str().unwrap());
assert_eq!(doc.data, "initial");
}
// After the initial document has been consumed, writing a file that does not
// match the pattern must not produce a document within one second.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_non_matching_files_ignored() {
let temp_dir = tempfile::tempdir().unwrap();
let matching_path = temp_dir.path().join("test.txt");
let non_matching_path = temp_dir.path().join("test.md");
std::fs::write(&matching_path, "initial").unwrap();
let builder = FileUpdatingLoaderBuilder::new(vec![temp_dir
.path()
.join("*.txt")
.to_str()
.unwrap()
.to_string()])
.unwrap();
let loader = builder.build();
let mut receiver = loader.subscribe().await;
// Drain the initial document for test.txt.
assert!(receiver.recv().await.is_ok());
std::fs::write(&non_matching_path, "modified").unwrap();
// Give the watcher time to observe and (not) forward the change.
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
assert!(receiver.try_recv().is_err());
}
}