use std::{collections::HashMap, sync::Arc};
use {reovim_driver_session::SessionExtension, reovim_kernel::api::v1::BufferId};
use crate::{LanguageRegistry, SyntaxDriver, SyntaxDriverFactory};
/// Session-scoped storage for syntax drivers, one per buffer.
///
/// The derived `Default` yields an empty driver map with no factory and no
/// registry configured.
#[derive(Default)]
pub struct SyntaxSessionState {
/// Syntax drivers keyed by `BufferId::as_usize()`.
drivers: HashMap<usize, Box<dyn SyntaxDriver>>,
/// Shared factory that `ensure_driver` uses to create new drivers; `None` until `set_factory` is called.
factory: Option<Arc<dyn SyntaxDriverFactory>>,
/// Shared registry that `detect_language` queries with a path; `None` until `set_registry` is called.
registry: Option<Arc<dyn LanguageRegistry>>,
}
impl SessionExtension for SyntaxSessionState {
    /// Returns an empty state, exactly as [`SyntaxSessionState::new`] does.
    fn create() -> Self {
        Self::new()
    }
}
impl SyntaxSessionState {
    /// Creates an empty state: no drivers, no factory, no registry.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Stores `factory` as the driver factory, replacing any previous one.
    pub fn set_factory(&mut self, factory: Arc<dyn SyntaxDriverFactory>) {
        self.factory = Some(factory);
    }

    /// Borrows the configured driver factory, or `None` if none has been set.
    #[must_use]
    pub fn factory(&self) -> Option<&dyn SyntaxDriverFactory> {
        self.factory.as_deref()
    }

    /// Stores `registry` as the language registry, replacing any previous one.
    pub fn set_registry(&mut self, registry: Arc<dyn LanguageRegistry>) {
        self.registry = Some(registry);
    }

    /// Borrows the configured language registry, or `None` if none has been set.
    #[must_use]
    pub fn registry(&self) -> Option<&dyn LanguageRegistry> {
        self.registry.as_deref()
    }

    /// Asks the registry which language `path` belongs to.
    ///
    /// Returns `None` when no registry is configured or when the registry's
    /// `detect_from_path` yields nothing for `path`.
    #[must_use]
    pub fn detect_language(&self, path: &str) -> Option<String> {
        self.registry
            .as_ref()
            .and_then(|registry| registry.detect_from_path(path))
    }

    /// Makes sure `buffer_id` has a driver, choosing the language from `path`.
    ///
    /// Returns `true` immediately if a driver is already stored for the buffer.
    /// Otherwise the language is looked up via [`Self::detect_language`] and the
    /// work is delegated to [`Self::ensure_driver`]; `false` is returned when
    /// the language cannot be detected or the driver cannot be created.
    pub fn ensure_driver_from_path(
        &mut self,
        buffer_id: BufferId,
        path: &str,
        content: &str,
    ) -> bool {
        let key = buffer_id.as_usize();
        if self.drivers.contains_key(&key) {
            return true;
        }

        match self.detect_language(path) {
            Some(language_id) => self.ensure_driver(buffer_id, &language_id, content),
            None => false,
        }
    }

    /// Borrows the driver stored for `buffer_id`, if there is one.
    #[must_use]
    pub fn get(&self, buffer_id: BufferId) -> Option<&dyn SyntaxDriver> {
        let key = buffer_id.as_usize();
        let boxed = self.drivers.get(&key)?;
        Some(&**boxed as &dyn SyntaxDriver)
    }

    /// Mutably borrows the driver stored for `buffer_id`, if there is one.
    pub fn get_mut(&mut self, buffer_id: BufferId) -> Option<&mut dyn SyntaxDriver> {
        let key = buffer_id.as_usize();
        let boxed = self.drivers.get_mut(&key)?;
        // The explicit cast shortens the trait-object lifetime to the borrow of `self`.
        Some(&mut **boxed as &mut dyn SyntaxDriver)
    }

    /// Stores `driver` for `buffer_id`, replacing any driver already there.
    pub fn set(&mut self, buffer_id: BufferId, driver: Box<dyn SyntaxDriver>) {
        let key = buffer_id.as_usize();
        self.drivers.insert(key, driver);
    }

    /// Removes and returns the driver stored for `buffer_id`, if there is one.
    pub fn remove(&mut self, buffer_id: BufferId) -> Option<Box<dyn SyntaxDriver>> {
        let key = buffer_id.as_usize();
        self.drivers.remove(&key)
    }

    /// Reports whether a driver is stored for `buffer_id`.
    #[must_use]
    pub fn has_driver(&self, buffer_id: BufferId) -> bool {
        let key = buffer_id.as_usize();
        self.drivers.contains_key(&key)
    }

    /// Makes sure `buffer_id` has a driver for `language_id`.
    ///
    /// If a driver is already stored for the buffer it is left untouched and
    /// `true` is returned. Otherwise the configured factory is asked to create
    /// a driver for `language_id`; the new driver receives a clone of the
    /// factory through `set_injection_factory`, parses `content`, and is
    /// stored. Returns `false` when no factory is configured or the factory
    /// produces no driver.
    pub fn ensure_driver(&mut self, buffer_id: BufferId, language_id: &str, content: &str) -> bool {
        let key = buffer_id.as_usize();
        if self.drivers.contains_key(&key) {
            return true;
        }

        let Some(factory) = self.factory.as_ref() else {
            return false;
        };
        let Some(mut driver) = factory.create(language_id) else {
            return false;
        };

        driver.set_injection_factory(Arc::clone(factory));
        driver.parse(content);
        self.drivers.insert(key, driver);
        true
    }

    /// Number of buffers that currently have a driver.
    #[must_use]
    pub fn len(&self) -> usize {
        self.drivers.len()
    }

    /// Returns `true` when no buffer has a driver.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.drivers.is_empty()
    }

    /// Drops every stored driver; the factory and registry are kept.
    pub fn clear(&mut self) {
        self.drivers.clear();
    }
}
impl std::fmt::Debug for SyntaxSessionState {
    /// Formats a summary of the state: the number of stored drivers and
    /// whether a factory and a registry are configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            drivers,
            factory,
            registry,
        } = self;

        let mut summary = f.debug_struct("SyntaxSessionState");
        summary.field("buffer_count", &drivers.len());
        summary.field("has_factory", &factory.is_some());
        summary.field("has_registry", &registry.is_some());
        summary.finish()
    }
}
// Test-only module; its source is loaded from `state_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "state_tests.rs"]
mod tests;