Skip to main content

panini_lang_core/
component.rs

1use std::fmt::Debug;
2
3use serde::de::DeserializeOwned;
4
5use crate::aggregable::digest::AggregationSink;
6use crate::traits::LinguisticDefinition;
7
/// A language the learner already speaks, together with their proficiency in it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LanguageLevel {
    /// ISO 639-3 code identifying the language (e.g. `"fra"`).
    pub iso_639_3: String,
    /// Proficiency level. Stored as a plain string; this type enforces no format.
    pub level: String,
}
14
15/// Context passed to components during schema/prompt generation.
16pub struct ComponentContext<'a> {
17    pub targets: &'a [String],
18    pub learner_ui_language: &'a str,
19    pub pedagogical_context: Option<&'a str>,
20    pub skill_path: Option<&'a str>,
21    pub linguistic_background: &'a [LanguageLevel],
22}
23
24/// A composable analysis component that contributes a section to the extraction schema,
25/// prompt, and output processing pipeline.
26///
27/// Each component owns one top-level key in the JSON output.
28/// Components are parameterized by the language definition `L` so they can
29/// access language-specific types and methods.
30pub trait AnalysisComponent<L: LinguisticDefinition>: Send + Sync + Debug {
31    /// Human-readable name for logging/display.
32    fn name(&self) -> &'static str;
33
34    /// The top-level JSON key this component produces (e.g. `"morphology"`).
35    fn schema_key(&self) -> &'static str;
36
37    /// Returns the JSON Schema fragment for this component's output.
38    /// This will be placed under `properties[schema_key]` in the composed schema.
39    fn schema_fragment(&self, lang: &L) -> serde_json::Value;
40
41    /// Returns prompt text describing what this component expects from the LLM.
42    fn prompt_fragment(&self, lang: &L, ctx: &ComponentContext) -> String;
43
44    /// Optional extra output instructions (appended to the output section).
45    fn output_instruction(&self) -> Option<&str> {
46        None
47    }
48
49    /// Pre-process the raw LLM JSON text before parsing.
50    /// Applied to the full JSON string; components are chained in order.
51    fn pre_process(&self, raw: &str) -> String {
52        raw.to_string()
53    }
54
55    /// Validate this component's section of the parsed JSON.
56    ///
57    /// # Errors
58    /// Returns a validation error string if the section does not conform to expected constraints.
59    fn validate(&self, _lang: &L, _section: &serde_json::Value) -> Result<(), String> {
60        Ok(())
61    }
62
63    /// Post-process this component's section of the parsed JSON (in place).
64    ///
65    /// # Errors
66    /// Returns an error string if post-processing logic fails.
67    fn post_process(&self, _lang: &L, _section: &mut serde_json::Value) -> Result<(), String> {
68        Ok(())
69    }
70
71    /// Whether this component is compatible with the given language.
72    /// Incompatible components are silently skipped.
73    fn is_compatible(&self, _lang: &L) -> bool {
74        true
75    }
76
77    /// Returns `Some(self)` for components that produce aggregable data.
78    ///
79    /// Override to return `Some(self)` in components that implement [`Aggregating<L>`].
80    /// Default returns `None` — non-aggregable components carry no aggregation logic.
81    fn as_aggregating(&self) -> Option<&dyn Aggregating<L>> {
82        None
83    }
84}
85
86// ─── ExtractionResult ────────────────────────────────────────────────────────
87
88/// Error type for `ExtractionResult` accessor methods.
89#[derive(Debug, thiserror::Error)]
90pub enum ExtractionResultError {
91    #[error("key not found: {key}")]
92    KeyNotFound { key: String },
93    #[error("deserialization error for key '{key}': {source}")]
94    DeserializeError {
95        key: String,
96        source: serde_json::Error,
97    },
98}
99
100/// Container for the composed extraction result.
101///
102/// Holds the raw JSON value (an object with one key per component)
103/// and provides typed accessors.
104#[derive(Debug, Clone)]
105pub struct ExtractionResult {
106    raw: serde_json::Value,
107    requested_keys: Vec<&'static str>,
108}
109
110impl ExtractionResult {
111    /// Create a new `ExtractionResult` from a raw JSON object and the list
112    /// of component keys that were requested.
113    #[must_use]
114    pub const fn new(raw: serde_json::Value, requested_keys: Vec<&'static str>) -> Self {
115        Self {
116            raw,
117            requested_keys,
118        }
119    }
120
121    /// Deserialize a component's section into a concrete type.
122    ///
123    /// # Errors
124    /// Returns `ExtractionResultError::KeyNotFound` if the key is not in the result.
125    /// Returns `ExtractionResultError::DeserializeError` if the section fails to deserialize into `T`.
126    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Result<T, ExtractionResultError> {
127        let section = self
128            .raw
129            .get(key)
130            .ok_or_else(|| ExtractionResultError::KeyNotFound {
131                key: key.to_string(),
132            })?;
133        serde_json::from_value(section.clone()).map_err(|e| {
134            ExtractionResultError::DeserializeError {
135                key: key.to_string(),
136                source: e,
137            }
138        })
139    }
140
141    /// Get the raw JSON value for a component's key.
142    #[must_use]
143    pub fn get_raw(&self, key: &str) -> Option<&serde_json::Value> {
144        self.raw.get(key)
145    }
146
147    /// Iterate over all (key, value) pairs in the raw JSON object.
148    pub fn iter_raw(&self) -> impl Iterator<Item = (&str, &serde_json::Value)> {
149        self.raw
150            .as_object()
151            .into_iter()
152            .flat_map(|obj| obj.iter().map(|(k, v)| (k.as_str(), v)))
153    }
154
155    /// The keys that were requested (i.e., the compatible components).
156    #[must_use]
157    pub fn requested_keys(&self) -> &[&'static str] {
158        &self.requested_keys
159    }
160
161    /// Consume and return the raw JSON value.
162    #[must_use]
163    pub fn into_raw(self) -> serde_json::Value {
164        self.raw
165    }
166}
167
168// ─── AggregationError ────────────────────────────────────────────────────────
169
170/// Typed error for [`Aggregating::aggregate_section`].
171#[derive(Debug, thiserror::Error)]
172pub enum AggregationError {
173    #[error("failed to deserialize section '{key}': {source}")]
174    Deserialize {
175        key: &'static str,
176        #[source]
177        source: serde_json::Error,
178    },
179    #[error("aggregation hook for '{key}' failed: {message}")]
180    Hook { key: &'static str, message: String },
181}
182
183// ─── Aggregating sub-trait ────────────────────────────────────────────────────
184
185/// Extension of [`AnalysisComponent`] for components that produce aggregable data.
186///
187/// Components opt in by overriding `as_aggregating` on `AnalysisComponent` to
188/// return `Some(self)`. Non-aggregable components (`PedagogicalExplanation`,
189/// `LeipzigAlignment`, etc.) do nothing.
190pub trait Aggregating<L: LinguisticDefinition>: AnalysisComponent<L> {
191    /// Project this component's JSON section into aggregation contributions.
192    ///
193    /// Called per-card with the deserialized (and post-processed) section value.
194    /// Implementations deserialize the section and push contributions to `sink`
195    /// via [`AggregationSink::record_contribution`] or the typed shim
196    /// [`AggregationSink::record`].
197    fn aggregate_section(
198        &self,
199        lang: &L,
200        section: &serde_json::Value,
201        sink: &mut dyn AggregationSink,
202    ) -> Result<(), AggregationError>;
203}
204
205// ─── Marker trait ─────────────────────────────────────────────────────────────
206
/// Compile-time marker stating that a component may be used with language `L`.
///
/// `#[derive(PaniniResult)]` relies on this trait to reject invalid
/// component/language pairings. Universal components implement it for every
/// `L: LinguisticDefinition`. Restricted components (e.g. `MorphemeSegmentation`)
/// add extra bounds on `L` (e.g. `L: Agglutinative`), so pairing them with an
/// incompatible language is a compile error.
pub trait ComponentRequires<L> {}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn get_typed_value() {
        let raw = serde_json::json!({
            "pedagogical_explanation": "This is a test.",
            "morphology": { "target_features": [], "context_features": [] }
        });
        let result = ExtractionResult::new(raw, vec!["pedagogical_explanation", "morphology"]);

        let explanation: String = result.get("pedagogical_explanation").unwrap();
        assert_eq!(explanation, "This is a test.");
    }

    #[test]
    fn get_missing_key_returns_key_not_found() {
        let raw = serde_json::json!({ "morphology": {} });
        let result = ExtractionResult::new(raw, vec!["morphology"]);

        let err = result.get::<String>("nonexistent").unwrap_err();
        assert!(matches!(err, ExtractionResultError::KeyNotFound { .. }));
    }

    #[test]
    fn get_wrong_type_returns_deserialize_error() {
        let raw = serde_json::json!({ "pedagogical_explanation": "not a number" });
        let result = ExtractionResult::new(raw, vec!["pedagogical_explanation"]);

        match result.get::<u32>("pedagogical_explanation") {
            Err(ExtractionResultError::DeserializeError { key, .. }) => {
                assert_eq!(key, "pedagogical_explanation");
            }
            other => panic!("expected DeserializeError, got {other:?}"),
        }
    }

    #[test]
    fn get_on_non_object_returns_key_not_found() {
        let result = ExtractionResult::new(serde_json::json!([1, 2, 3]), vec![]);

        let err = result.get::<u32>("morphology").unwrap_err();
        assert!(matches!(err, ExtractionResultError::KeyNotFound { .. }));
    }

    #[test]
    fn get_raw_returns_section() {
        let raw = serde_json::json!({ "morphology": { "target_features": [] } });
        let result = ExtractionResult::new(raw, vec!["morphology"]);

        assert!(result.get_raw("morphology").is_some());
        assert!(result.get_raw("nonexistent").is_none());
    }

    #[test]
    fn iter_raw_returns_all_entries() {
        let raw = serde_json::json!({
            "a": 1,
            "b": 2,
            "c": 3
        });
        let result = ExtractionResult::new(raw, vec![]);

        let keys: Vec<&str> = result.iter_raw().map(|(k, _)| k).collect();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"a"));
        assert!(keys.contains(&"b"));
        assert!(keys.contains(&"c"));
    }

    #[test]
    fn iter_raw_on_non_object_is_empty() {
        let result = ExtractionResult::new(serde_json::json!("just a string"), vec![]);
        assert_eq!(result.iter_raw().count(), 0);
    }

    #[test]
    fn requested_keys_round_trip() {
        let raw = serde_json::json!({ "morphology": {} });
        let result = ExtractionResult::new(raw, vec!["morphology", "pedagogical_explanation"]);

        assert_eq!(
            result.requested_keys(),
            &["morphology", "pedagogical_explanation"][..]
        );
    }

    #[test]
    fn into_raw_consumes() {
        let raw = serde_json::json!({ "key": "value" });
        let result = ExtractionResult::new(raw.clone(), vec!["key"]);
        assert_eq!(result.into_raw(), raw);
    }
}