// ferro_ai/classifier/provider.rs
1use crate::error::Error;
2use async_trait::async_trait;
3
4use super::ClassifierConfig;
5
6/// Trait for AI classification backends.
7///
8/// Implement this trait to plug in any AI provider (Anthropic, OpenAI, local models).
9/// The single method receives raw prompts and a JSON schema, and returns a JSON value
10/// matching that schema.
11///
12/// # Object safety
13///
14/// This trait is object-safe and can be used as `Arc<dyn ClassificationProvider>`.
15///
16/// # Example
17///
18/// ```rust,ignore
19/// use ferro_ai::{ClassificationProvider, ClassifierConfig};
20/// use async_trait::async_trait;
21///
22/// struct EchoProvider;
23///
24/// #[async_trait]
25/// impl ClassificationProvider for EchoProvider {
26/// async fn classify_raw(
27/// &self,
28/// _system_prompt: &str,
29/// user_prompt: &str,
30/// _schema: &serde_json::Value,
31/// _config: &ClassifierConfig,
32/// ) -> Result<serde_json::Value, ferro_ai::Error> {
33/// Ok(serde_json::json!({"echo": user_prompt}))
34/// }
35/// }
36/// ```
37#[async_trait]
38pub trait ClassificationProvider: Send + Sync {
39 /// Call the AI provider with raw prompts and return structured JSON.
40 ///
41 /// The returned `serde_json::Value` must conform to the provided `schema`.
42 /// The schema is passed as a JSON Schema value; callers generate it via
43 /// `schemars::schema_for!(T)` or build it manually.
44 async fn classify_raw(
45 &self,
46 system_prompt: &str,
47 user_prompt: &str,
48 schema: &serde_json::Value,
49 config: &ClassifierConfig,
50 ) -> Result<serde_json::Value, Error>;
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Stub backend: ignores every input and replays a canned JSON value.
    struct CannedProvider {
        canned: serde_json::Value,
    }

    #[async_trait]
    impl ClassificationProvider for CannedProvider {
        async fn classify_raw(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _schema: &serde_json::Value,
            _config: &ClassifierConfig,
        ) -> Result<serde_json::Value, Error> {
            let reply = self.canned.clone();
            Ok(reply)
        }
    }

    #[test]
    fn test_classification_provider_is_object_safe() {
        let stub = CannedProvider {
            canned: serde_json::json!({"result": "ok"}),
        };
        // Coercing into a trait object only type-checks when the trait is object-safe,
        // so compiling this line is the whole assertion.
        let shared: Arc<dyn ClassificationProvider> = Arc::new(stub);
        drop(shared);
    }
}