Skip to main content

rs_adk/processors/
mod.rs

//! Request/response processors: middleware for LLM request pipelines.
//!
//! Unlike ADK-JS, where processors are baked into the LLM request pipeline,
//! these processors compose as middleware, so they work with any `BaseLlm`,
//! not just Gemini.
6
7use async_trait::async_trait;
8
9use crate::llm::{LlmRequest, LlmResponse};
10
11/// Errors from processor operations.
12#[derive(Debug, thiserror::Error)]
13pub enum ProcessorError {
14    /// An error during request processing.
15    #[error("Processor error: {0}")]
16    Processing(String),
17}
18
19/// Trait for processing LLM requests before they are sent.
20#[async_trait]
21pub trait RequestProcessor: Send + Sync {
22    /// Processor name for logging/debugging.
23    fn name(&self) -> &str;
24
25    /// Process the request, potentially modifying it.
26    async fn process_request(&self, request: LlmRequest) -> Result<LlmRequest, ProcessorError>;
27}
28
29/// Trait for processing LLM responses after they are received.
30#[async_trait]
31pub trait ResponseProcessor: Send + Sync {
32    /// Processor name for logging/debugging.
33    fn name(&self) -> &str;
34
35    /// Process the response, potentially modifying it.
36    async fn process_response(&self, response: LlmResponse) -> Result<LlmResponse, ProcessorError>;
37}
38
39/// Processor that prepends a system instruction to every request.
40pub struct InstructionInserter {
41    instruction: String,
42}
43
44impl InstructionInserter {
45    /// Create a new instruction inserter.
46    pub fn new(instruction: impl Into<String>) -> Self {
47        Self {
48            instruction: instruction.into(),
49        }
50    }
51}
52
53#[async_trait]
54impl RequestProcessor for InstructionInserter {
55    fn name(&self) -> &str {
56        "instruction_inserter"
57    }
58
59    async fn process_request(&self, mut request: LlmRequest) -> Result<LlmRequest, ProcessorError> {
60        match &mut request.system_instruction {
61            Some(existing) => {
62                existing.push('\n');
63                existing.push_str(&self.instruction);
64            }
65            None => {
66                request.system_instruction = Some(self.instruction.clone());
67            }
68        }
69        Ok(request)
70    }
71}
72
73/// Processor that filters content parts, keeping only those that match a predicate.
74pub struct ContentFilter {
75    /// Keep only text parts.
76    text_only: bool,
77}
78
79impl ContentFilter {
80    /// Create a filter that keeps only text parts.
81    pub fn text_only() -> Self {
82        Self { text_only: true }
83    }
84}
85
86#[async_trait]
87impl RequestProcessor for ContentFilter {
88    fn name(&self) -> &str {
89        "content_filter"
90    }
91
92    async fn process_request(&self, mut request: LlmRequest) -> Result<LlmRequest, ProcessorError> {
93        if self.text_only {
94            for content in &mut request.contents {
95                content
96                    .parts
97                    .retain(|p| matches!(p, rs_genai::prelude::Part::Text { .. }));
98            }
99        }
100        Ok(request)
101    }
102}
103
104/// An ordered chain of request processors.
105#[derive(Default)]
106pub struct RequestProcessorChain {
107    processors: Vec<Box<dyn RequestProcessor>>,
108}
109
110impl RequestProcessorChain {
111    /// Create an empty chain.
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    /// Add a processor to the end of the chain.
117    pub fn add(&mut self, processor: impl RequestProcessor + 'static) {
118        self.processors.push(Box::new(processor));
119    }
120
121    /// Process a request through all processors in order.
122    pub async fn process(&self, mut request: LlmRequest) -> Result<LlmRequest, ProcessorError> {
123        for processor in &self.processors {
124            request = processor.process_request(request).await?;
125        }
126        Ok(request)
127    }
128
129    /// Number of processors in the chain.
130    pub fn len(&self) -> usize {
131        self.processors.len()
132    }
133
134    /// Returns true if chain is empty.
135    pub fn is_empty(&self) -> bool {
136        self.processors.is_empty()
137    }
138}
139
140/// An ordered chain of response processors.
141#[derive(Default)]
142pub struct ResponseProcessorChain {
143    processors: Vec<Box<dyn ResponseProcessor>>,
144}
145
146impl ResponseProcessorChain {
147    /// Create an empty chain.
148    pub fn new() -> Self {
149        Self::default()
150    }
151
152    /// Add a processor to the end of the chain.
153    pub fn add(&mut self, processor: impl ResponseProcessor + 'static) {
154        self.processors.push(Box::new(processor));
155    }
156
157    /// Process a response through all processors in order.
158    pub async fn process(&self, mut response: LlmResponse) -> Result<LlmResponse, ProcessorError> {
159        for processor in &self.processors {
160            response = processor.process_response(response).await?;
161        }
162        Ok(response)
163    }
164
165    /// Number of processors in the chain.
166    pub fn len(&self) -> usize {
167        self.processors.len()
168    }
169
170    /// Returns true if chain is empty.
171    pub fn is_empty(&self) -> bool {
172        self.processors.is_empty()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LlmRequest;
    use rs_genai::prelude::{Blob, Content, Part, Role};

    /// Compile-time check: `RequestProcessor` can be used as a trait object.
    #[test]
    fn request_processor_is_object_safe() {
        fn _assert(_: &dyn RequestProcessor) {}
    }

    /// Compile-time check: `ResponseProcessor` can be used as a trait object.
    #[test]
    fn response_processor_is_object_safe() {
        fn _assert(_: &dyn ResponseProcessor) {}
    }

    #[tokio::test]
    async fn instruction_inserter_sets_instruction() {
        let out = InstructionInserter::new("Be helpful")
            .process_request(LlmRequest::from_text("Hello"))
            .await
            .unwrap();
        assert_eq!(out.system_instruction.as_deref(), Some("Be helpful"));
    }

    #[tokio::test]
    async fn instruction_inserter_appends_to_existing() {
        let mut request = LlmRequest::from_text("Hello");
        request.system_instruction = Some(String::from("Be helpful"));

        let out = InstructionInserter::new("And concise")
            .process_request(request)
            .await
            .unwrap();
        assert_eq!(
            out.system_instruction.as_deref(),
            Some("Be helpful\nAnd concise")
        );
    }

    #[tokio::test]
    async fn content_filter_text_only() {
        let text = Part::Text {
            text: "hello".into(),
        };
        let image = Part::InlineData {
            inline_data: Blob {
                mime_type: "image/png".into(),
                data: "base64data".into(),
            },
        };
        let request = LlmRequest {
            contents: vec![Content {
                role: Some(Role::User),
                parts: vec![text, image],
            }],
            ..Default::default()
        };

        let out = ContentFilter::text_only()
            .process_request(request)
            .await
            .unwrap();
        let parts = &out.contents[0].parts;
        assert_eq!(parts.len(), 1);
        assert!(matches!(parts[0], Part::Text { .. }));
    }

    #[tokio::test]
    async fn request_processor_chain() {
        let mut chain = RequestProcessorChain::new();
        for rule in ["Rule 1", "Rule 2"] {
            chain.add(InstructionInserter::new(rule));
        }

        let out = chain.process(LlmRequest::from_text("Hello")).await.unwrap();
        assert_eq!(out.system_instruction.as_deref(), Some("Rule 1\nRule 2"));
    }

    #[test]
    fn chain_len() {
        let mut chain = RequestProcessorChain::new();
        assert!(chain.is_empty());

        chain.add(InstructionInserter::new("x"));
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }
}