// teaql_runtime/context.rs

use std::any::{Any, TypeId};
use std::collections::{BTreeMap, HashMap};
use std::sync::{Mutex, MutexGuard, PoisonError};

use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};

use crate::{
    CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
    InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
    RepositoryBehaviorRegistry, RepositoryRegistry, RuntimeError, local_id_generator,
    translate_check_result,
};
use crate::{EntityRoot, QueryExecutor, RepositoryError};

/// Kind of SQL statement captured by the per-context SQL log.
///
/// [`SqlLogOperation::Select`] is the only non-mutating kind; every other
/// variant is treated as a mutation when deciding whether to log it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every kind other than [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        // Spelled out exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}

35#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
36pub struct SqlLogOptions {
37    pub select: bool,
38    pub mutation: bool,
39}
40
41impl SqlLogOptions {
42    pub fn select_only() -> Self {
43        Self {
44            select: true,
45            mutation: false,
46        }
47    }
48
49    pub fn mutation_only() -> Self {
50        Self {
51            select: false,
52            mutation: true,
53        }
54    }
55
56    pub fn all() -> Self {
57        Self {
58            select: true,
59            mutation: true,
60        }
61    }
62
63    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
64        if operation.is_select() {
65            self.select
66        } else {
67            self.mutation
68        }
69    }
70}
71
72#[derive(Debug, Clone, PartialEq)]
73pub struct SqlLogEntry {
74    pub operation: SqlLogOperation,
75    pub sql: String,
76    pub params: Vec<Value>,
77    pub debug_sql: String,
78}
79
80#[derive(Default)]
81pub struct UserContext {
82    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
83    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
84    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
85    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
86    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
87    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
88    language: Language,
89    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
90    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
91    locals: BTreeMap<String, Value>,
92    pub(crate) initial_graphs: Vec<GraphNode>,
93    entity_root: EntityRoot,
94    sql_log_options: SqlLogOptions,
95    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
96}
97
98impl UserContext {
99    pub fn new() -> Self {
100        Self::default()
101    }
102
103    pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
104        module.apply_to(&mut self);
105        self
106    }
107
108    pub fn entity_root(&self) -> EntityRoot {
109        self.entity_root.clone()
110    }
111
112    pub fn initial_graphs(&self) -> &[GraphNode] {
113        &self.initial_graphs
114    }
115
116    pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
117        self.initial_graphs = graphs;
118    }
119
120    pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
121        self.metadata = Some(Box::new(metadata));
122        self
123    }
124
125    pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
126        self.metadata = Some(Box::new(metadata));
127    }
128
129    pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
130        self.repository_registry = Some(Box::new(registry));
131        self
132    }
133
134    pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
135        self.repository_registry = Some(Box::new(registry));
136    }
137
138    pub fn with_repository_behavior_registry(
139        mut self,
140        registry: impl RepositoryBehaviorRegistry + 'static,
141    ) -> Self {
142        self.repository_behavior_registry = Some(Box::new(registry));
143        self
144    }
145
146    pub fn set_repository_behavior_registry(
147        &mut self,
148        registry: impl RepositoryBehaviorRegistry + 'static,
149    ) {
150        self.repository_behavior_registry = Some(Box::new(registry));
151    }
152
153    pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
154        self.checker_registry = Some(Box::new(registry));
155        self
156    }
157
158    pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
159        self.checker_registry = Some(Box::new(registry));
160    }
161
162    pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
163        self.event_sink = Some(Box::new(sink));
164        self
165    }
166
167    pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
168        self.event_sink = Some(Box::new(sink));
169    }
170
171    pub fn with_internal_id_generator(
172        mut self,
173        generator: impl InternalIdGenerator + 'static,
174    ) -> Self {
175        self.internal_id_generator = Some(Box::new(generator));
176        self
177    }
178
179    pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
180        self.internal_id_generator = Some(Box::new(generator));
181    }
182
183    pub fn with_language(mut self, language: Language) -> Self {
184        self.language = language;
185        self
186    }
187
188    pub fn set_language(&mut self, language: Language) {
189        self.language = language;
190    }
191
192    pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
193        self.sql_log_options = options;
194        self
195    }
196
197    pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
198        self.sql_log_options = options;
199    }
200
201    pub fn enable_select_sql_log(&mut self) {
202        self.sql_log_options.select = true;
203    }
204
205    pub fn enable_mutation_sql_log(&mut self) {
206        self.sql_log_options.mutation = true;
207    }
208
209    pub fn enable_all_sql_log(&mut self) {
210        self.sql_log_options = SqlLogOptions::all();
211    }
212
213    pub fn disable_sql_log(&mut self) {
214        self.sql_log_options = SqlLogOptions::default();
215        self.clear_sql_logs();
216    }
217
218    pub fn sql_log_options(&self) -> SqlLogOptions {
219        self.sql_log_options
220    }
221
222    pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
223        self.sql_log_entries
224            .lock()
225            .map(|entries| entries.clone())
226            .unwrap_or_default()
227    }
228
229    pub fn clear_sql_logs(&self) {
230        if let Ok(mut entries) = self.sql_log_entries.lock() {
231            entries.clear();
232        }
233    }
234
235    pub(crate) fn record_sql_log(
236        &self,
237        operation: SqlLogOperation,
238        query: &CompiledQuery,
239        database_kind: DatabaseKind,
240    ) {
241        if !self.sql_log_options.enabled_for(operation) {
242            return;
243        }
244        if let Ok(mut entries) = self.sql_log_entries.lock() {
245            entries.push(SqlLogEntry {
246                operation,
247                sql: query.sql.clone(),
248                params: query.params.clone(),
249                debug_sql: query.debug_sql(database_kind),
250            });
251        }
252    }
253
254    pub fn language(&self) -> Language {
255        self.language
256    }
257
258    pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
259        let Some(language) = Language::from_code(code) else {
260            return Err(RuntimeError::Language(format!(
261                "unsupported language code: {code}"
262            )));
263        };
264        self.language = language;
265        Ok(())
266    }
267
268    pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
269        self.internal_id_generator
270            .as_ref()
271            .map(|generator| generator.generate_id(entity))
272            .transpose()
273    }
274
275    pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
276        match self.generate_id(entity)? {
277            Some(id) => Ok(id),
278            None => local_id_generator().generate_id(entity),
279        }
280    }
281
282    pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
283        self.metadata
284            .as_ref()
285            .and_then(|metadata| metadata.entity(name))
286    }
287
288    pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
289        self.entity(name)
290            .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
291    }
292
293    pub fn insert_resource<T>(&mut self, resource: T)
294    where
295        T: Send + Sync + 'static,
296    {
297        self.typed_resources
298            .insert(TypeId::of::<T>(), Box::new(resource));
299    }
300
301    pub fn get_resource<T>(&self) -> Option<&T>
302    where
303        T: Send + Sync + 'static,
304    {
305        self.typed_resources
306            .get(&TypeId::of::<T>())
307            .and_then(|value| value.downcast_ref::<T>())
308    }
309
310    pub fn require_resource<T>(&self) -> Result<&T, ContextError>
311    where
312        T: Send + Sync + 'static,
313    {
314        self.get_resource::<T>()
315            .ok_or(ContextError::MissingTypedResource(
316                std::any::type_name::<T>(),
317            ))
318    }
319
320    pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
321    where
322        T: Send + Sync + 'static,
323    {
324        self.named_resources.insert(name.into(), Box::new(resource));
325    }
326
327    pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
328    where
329        T: Send + Sync + 'static,
330    {
331        self.named_resources
332            .get(name)
333            .and_then(|value| value.downcast_ref::<T>())
334    }
335
336    pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
337    where
338        T: Send + Sync + 'static,
339    {
340        self.get_named_resource::<T>(name)
341            .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
342    }
343
344    pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
345        self.locals.insert(key.into(), value.into());
346    }
347
348    pub fn local(&self, key: &str) -> Option<&Value> {
349        self.locals.get(key)
350    }
351
352    pub fn remove_local(&mut self, key: &str) -> Option<Value> {
353        self.locals.remove(key)
354    }
355
356    pub fn has_repository(&self, entity: &str) -> bool {
357        let in_registry = self
358            .repository_registry
359            .as_ref()
360            .map(|registry| registry.contains(entity))
361            .unwrap_or(false);
362        in_registry || self.entity(entity).is_some()
363    }
364
365    pub fn repository_behavior(
366        &self,
367        entity: &str,
368    ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
369        self.repository_behavior_registry
370            .as_ref()
371            .and_then(|registry| registry.behavior(entity))
372    }
373
374    pub fn has_checker(&self, entity: &str) -> bool {
375        self.checker_registry
376            .as_ref()
377            .and_then(|registry| registry.checker(entity))
378            .is_some()
379    }
380
381    pub fn check_and_fix_record(
382        &self,
383        entity: &str,
384        record: &mut Record,
385    ) -> Result<(), RuntimeError> {
386        self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
387    }
388
389    pub fn check_and_fix_record_at(
390        &self,
391        entity: &str,
392        record: &mut Record,
393        location: &ObjectLocation,
394    ) -> Result<(), RuntimeError> {
395        let Some(checker) = self
396            .checker_registry
397            .as_ref()
398            .and_then(|registry| registry.checker(entity))
399        else {
400            return Ok(());
401        };
402        let mut results = CheckResults::new();
403        checker.check_and_fix(self, record, location, &mut results);
404        if results.is_empty() {
405            Ok(())
406        } else {
407            self.translate_check_results(&mut results);
408            Err(RuntimeError::Check(results))
409        }
410    }
411
412    pub fn translate_check_results(&self, results: &mut CheckResults) {
413        for result in results {
414            result.message = Some(translate_check_result(self.language, result));
415        }
416    }
417
418    pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
419        let Some(sink) = self.event_sink.as_ref() else {
420            return Ok(());
421        };
422        sink.on_event(self, &event)
423    }
424
425    pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
426    where
427        D: SqlDialect + Send + Sync + 'static,
428        E: QueryExecutor + Send + Sync + 'static,
429    {
430        let dialect = self.require_resource::<D>().map_err(|err| {
431            RepositoryError::Runtime(RuntimeError::Graph(format!(
432                "cannot commit changes without dialect: {err}"
433            )))
434        })?;
435        let executor = self.require_resource::<E>().map_err(|err| {
436            RepositoryError::Runtime(RuntimeError::Graph(format!(
437                "cannot commit changes without executor: {err}"
438            )))
439        })?;
440        let change_set = self.entity_root.current_change_set();
441
442        for (key, changes) in change_set.changes() {
443            if changes.is_empty() {
444                continue;
445            }
446            let entity = self
447                .require_entity(&key.entity)
448                .map_err(RepositoryError::Runtime)?;
449            let mut command = UpdateCommand::new(&key.entity, key.id.clone());
450            for (field, value) in changes {
451                command = command.value(field.clone(), value.clone());
452            }
453            let query = dialect
454                .compile_update(entity, &command)
455                .map_err(RuntimeError::from)
456                .map_err(RepositoryError::Runtime)?;
457            executor
458                .execute(&query)
459                .map_err(RepositoryError::Executor)?;
460        }
461
462        self.entity_root.clear_current_change_set();
463        Ok(())
464    }
465}