// ferro_ai/classifier/provider.rs

1use crate::error::Error;
2use async_trait::async_trait;
3
4use super::ClassifierConfig;
5
6/// Trait for AI classification backends.
7///
8/// Implement this trait to plug in any AI provider (Anthropic, OpenAI, local models).
9/// The single method receives raw prompts and a JSON schema, and returns a JSON value
10/// matching that schema.
11///
12/// # Object safety
13///
14/// This trait is object-safe and can be used as `Arc<dyn ClassificationProvider>`.
15///
16/// # Example
17///
18/// ```rust,ignore
19/// use ferro_ai::{ClassificationProvider, ClassifierConfig};
20/// use async_trait::async_trait;
21///
22/// struct EchoProvider;
23///
24/// #[async_trait]
25/// impl ClassificationProvider for EchoProvider {
26///     async fn classify_raw(
27///         &self,
28///         _system_prompt: &str,
29///         user_prompt: &str,
30///         _schema: &serde_json::Value,
31///         _config: &ClassifierConfig,
32///     ) -> Result<serde_json::Value, ferro_ai::Error> {
33///         Ok(serde_json::json!({"echo": user_prompt}))
34///     }
35/// }
36/// ```
37#[async_trait]
38pub trait ClassificationProvider: Send + Sync {
39    /// Call the AI provider with raw prompts and return structured JSON.
40    ///
41    /// The returned `serde_json::Value` must conform to the provided `schema`.
42    /// The schema is passed as a JSON Schema value; callers generate it via
43    /// `schemars::schema_for!(T)` or build it manually.
44    async fn classify_raw(
45        &self,
46        system_prompt: &str,
47        user_prompt: &str,
48        schema: &serde_json::Value,
49        config: &ClassifierConfig,
50    ) -> Result<serde_json::Value, Error>;
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Test double that ignores all of its inputs and hands back a pre-set JSON value.
    struct CannedProvider {
        /// Value returned (cloned) from every `classify_raw` call.
        canned: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for CannedProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let reply = self.canned.clone();
            Ok(reply)
        }
    }

    #[test]
    fn test_classification_provider_is_object_safe() {
        let canned = serde_json::json!({"result": "ok"});
        // Coercing a concrete provider into `Arc<dyn ...>` only type-checks when the
        // trait is object-safe; the check happens at compile time, nothing needs to run.
        let shared: Arc<dyn ClassificationProvider> = Arc::new(CannedProvider { canned });
        drop(shared);
    }
}