ecl_workflows/
critique_loop.rs

1//! Critique-Revise workflow with bounded feedback loop.
2
3use ecl_core::llm::{CompletionRequest, LlmProvider, Message};
4use ecl_core::{CritiqueDecision, Error, Result, WorkflowId};
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
/// Default cap on revise passes, applied whenever the caller leaves
/// `CritiqueLoopInput::max_revisions` unset.
const MAX_REVISIONS: u32 = 3;
10
11/// Input for critique-revise workflow.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CritiqueLoopInput {
14    /// Unique workflow ID
15    pub workflow_id: WorkflowId,
16
17    /// Topic to generate content about
18    pub topic: String,
19
20    /// Optional custom max revisions (defaults to MAX_REVISIONS)
21    pub max_revisions: Option<u32>,
22}
23
24impl CritiqueLoopInput {
25    /// Creates a new critique loop input.
26    pub fn new(topic: impl Into<String>) -> Self {
27        Self {
28            workflow_id: WorkflowId::new(),
29            topic: topic.into(),
30            max_revisions: None,
31        }
32    }
33
34    /// Sets a custom max revisions limit.
35    pub fn with_max_revisions(mut self, max: u32) -> Self {
36        self.max_revisions = Some(max);
37        self
38    }
39}
40
41/// Output from critique-revise workflow.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CritiqueLoopOutput {
44    /// Workflow ID
45    pub workflow_id: WorkflowId,
46
47    /// Final approved text
48    pub final_text: String,
49
50    /// Number of revision iterations performed
51    pub revision_count: u32,
52
53    /// All critiques generated during the workflow
54    pub critiques: Vec<String>,
55}
56
57/// Workflow with critique and revision loop.
58#[derive(Clone)]
59pub struct CritiqueLoopWorkflow {
60    llm: Arc<dyn LlmProvider>,
61}
62
63impl CritiqueLoopWorkflow {
64    /// Creates a new critique loop workflow.
65    pub fn new(llm: Arc<dyn LlmProvider>) -> Self {
66        Self { llm }
67    }
68
69    /// Runs the critique-revise workflow with bounded iteration.
70    pub async fn run(&self, input: CritiqueLoopInput) -> Result<CritiqueLoopOutput> {
71        let max_revisions = input.max_revisions.unwrap_or(MAX_REVISIONS);
72
73        tracing::info!(
74            workflow_id = %input.workflow_id,
75            topic = %input.topic,
76            max_revisions = max_revisions,
77            "Starting critique-revise workflow"
78        );
79
80        // Step 1: Generate initial draft
81        let mut current_draft = self.generate_step(&input.topic).await?;
82
83        let mut revision_count = 0u32;
84        let mut critiques = Vec::new();
85
86        // Revision loop with bounded iterations
87        loop {
88            // Step 2: Critique current draft
89            let (critique_text, decision) =
90                self.critique_step(&current_draft, revision_count).await?;
91
92            critiques.push(critique_text.clone());
93
94            match decision {
95                CritiqueDecision::Pass => {
96                    tracing::info!(
97                        workflow_id = %input.workflow_id,
98                        revision_count,
99                        "Critique passed, workflow complete"
100                    );
101                    break;
102                }
103                CritiqueDecision::Revise { feedback } => {
104                    if revision_count >= max_revisions {
105                        tracing::warn!(
106                            workflow_id = %input.workflow_id,
107                            attempts = max_revisions,
108                            "Maximum revisions exceeded"
109                        );
110                        return Err(Error::MaxRevisionsExceeded {
111                            attempts: max_revisions,
112                        });
113                    }
114
115                    tracing::info!(
116                        workflow_id = %input.workflow_id,
117                        revision_count,
118                        feedback = %feedback,
119                        "Revision requested"
120                    );
121
122                    // Step 3: Revise based on feedback
123                    current_draft = self
124                        .revise_step(&current_draft, &feedback, revision_count)
125                        .await?;
126
127                    revision_count += 1;
128                }
129                // Handle future variants (non_exhaustive)
130                #[allow(unreachable_patterns)]
131                _ => {
132                    return Err(Error::validation("Unknown critique decision variant"));
133                }
134            }
135        }
136
137        tracing::info!(
138            workflow_id = %input.workflow_id,
139            revision_count,
140            "Critique-revise workflow completed"
141        );
142
143        Ok(CritiqueLoopOutput {
144            workflow_id: input.workflow_id,
145            final_text: current_draft,
146            revision_count,
147            critiques,
148        })
149    }
150
151    /// Generate initial content.
152    async fn generate_step(&self, topic: &str) -> Result<String> {
153        tracing::info!(topic = %topic, "Generating initial content");
154
155        let request = CompletionRequest::new(vec![Message::user(format!(
156            "Write a paragraph about: {}",
157            topic
158        ))])
159        .with_system_prompt("You are a content generator. Write clear paragraphs.")
160        .with_max_tokens(500);
161
162        let response = self.llm.complete(request).await?;
163
164        tracing::info!(tokens = response.tokens_used.total(), "Content generated");
165
166        Ok(response.content)
167    }
168
169    /// Critique content and decide if revision is needed.
170    async fn critique_step(
171        &self,
172        content: &str,
173        attempt: u32,
174    ) -> Result<(String, CritiqueDecision)> {
175        tracing::info!(attempt, "Critiquing content");
176
177        let request = CompletionRequest::new(vec![Message::user(format!(
178            "Critique this text and decide if it needs revision.\n\
179            Respond with JSON: {{\"decision\": \"pass\" or \"revise\", \"critique\": \"your critique\", \"feedback\": \"what to improve\"}}\n\n\
180            Text:\n{}",
181            content
182        ))])
183        .with_system_prompt("You are a writing critic. Be helpful but thorough.")
184        .with_max_tokens(400);
185
186        let response = self.llm.complete(request).await?;
187
188        // Parse JSON response
189        let parsed: serde_json::Value = serde_json::from_str(&response.content)
190            .map_err(|e| Error::validation(format!("Failed to parse critique JSON: {}", e)))?;
191
192        let critique = parsed["critique"]
193            .as_str()
194            .ok_or_else(|| Error::validation("Missing critique field"))?
195            .to_string();
196
197        let decision = match parsed["decision"].as_str() {
198            Some("pass") => CritiqueDecision::Pass,
199            Some("revise") => {
200                let feedback = parsed["feedback"]
201                    .as_str()
202                    .ok_or_else(|| Error::validation("Missing feedback for revise decision"))?
203                    .to_string();
204                CritiqueDecision::Revise { feedback }
205            }
206            _ => return Err(Error::validation("Invalid decision value")),
207        };
208
209        tracing::info!(
210            attempt,
211            decision = ?decision,
212            "Critique step completed"
213        );
214
215        Ok((critique, decision))
216    }
217
218    /// Revise content based on feedback.
219    async fn revise_step(&self, original: &str, feedback: &str, attempt: u32) -> Result<String> {
220        tracing::info!(attempt, "Revising content");
221
222        let request = CompletionRequest::new(vec![Message::user(format!(
223            "Revise this text based on the feedback:\n\n\
224            Original:\n{}\n\n\
225            Feedback:\n{}",
226            original, feedback
227        ))])
228        .with_system_prompt("You are a content editor. Improve the text based on feedback.")
229        .with_max_tokens(600);
230
231        let response = self.llm.complete(request).await?;
232
233        tracing::info!(
234            attempt,
235            tokens = response.tokens_used.total(),
236            "Revision completed"
237        );
238
239        Ok(response.content)
240    }
241}
242
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use ecl_core::llm::MockLlmProvider;

    /// Builds a workflow whose mock LLM replays `responses` in call order.
    fn workflow_with(responses: &[&str]) -> CritiqueLoopWorkflow {
        let mock_llm = Arc::new(MockLlmProvider::new(
            responses.iter().map(|r| (*r).to_owned()).collect(),
        ));
        CritiqueLoopWorkflow::new(mock_llm)
    }

    #[test]
    fn test_critique_loop_input_creation() {
        let input = CritiqueLoopInput::new("Test topic");
        assert_eq!(input.topic, "Test topic");
        assert_eq!(input.max_revisions, None);
    }

    #[test]
    fn test_critique_loop_with_max_revisions() {
        let input = CritiqueLoopInput::new("Test").with_max_revisions(5);
        assert_eq!(input.max_revisions, Some(5));
    }

    #[tokio::test]
    async fn test_critique_loop_pass_immediately() {
        // Mock responses: generate, then critique that passes
        let workflow = workflow_with(&[
            "Generated content.",
            r#"{"decision": "pass", "critique": "Looks good!"}"#,
        ]);
        let input = CritiqueLoopInput::new("Test topic");

        let output = workflow.run(input.clone()).await.unwrap();

        assert_eq!(output.workflow_id, input.workflow_id);
        assert_eq!(output.final_text, "Generated content.");
        assert_eq!(output.revision_count, 0);
        assert_eq!(output.critiques, vec!["Looks good!"]);
    }

    #[tokio::test]
    async fn test_critique_loop_with_one_revision() {
        // Mock responses: generate, critique (revise), revise, critique (pass)
        let workflow = workflow_with(&[
            "Initial draft.",
            r#"{"decision": "revise", "critique": "Needs work", "feedback": "Add more detail"}"#,
            "Improved draft with more detail.",
            r#"{"decision": "pass", "critique": "Much better!"}"#,
        ]);
        let input = CritiqueLoopInput::new("Test topic");

        let output = workflow.run(input).await.unwrap();

        assert_eq!(output.final_text, "Improved draft with more detail.");
        assert_eq!(output.revision_count, 1);
        // Critiques are kept in order, including the approving one.
        assert_eq!(output.critiques, vec!["Needs work", "Much better!"]);
    }

    #[tokio::test]
    async fn test_critique_loop_max_revisions_exceeded() {
        // Mock responses that always request revision
        let workflow = workflow_with(&[
            "Draft.",
            r#"{"decision": "revise", "critique": "Try again", "feedback": "More work needed"}"#,
            "Revised 1.",
            r#"{"decision": "revise", "critique": "Still not good", "feedback": "Keep trying"}"#,
            "Revised 2.",
            r#"{"decision": "revise", "critique": "Nope", "feedback": "Again"}"#,
            "Revised 3.",
            r#"{"decision": "revise", "critique": "Still no", "feedback": "More"}"#,
        ]);
        let input = CritiqueLoopInput::new("Test topic");

        match workflow.run(input).await {
            Err(Error::MaxRevisionsExceeded { attempts }) => assert_eq!(attempts, MAX_REVISIONS),
            other => panic!("expected MaxRevisionsExceeded, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_critique_loop_custom_max_revisions() {
        let workflow = workflow_with(&[
            "Draft.",
            r#"{"decision": "revise", "critique": "Revise", "feedback": "Improve"}"#,
            "Revised 1.",
            r#"{"decision": "revise", "critique": "Again", "feedback": "More"}"#,
        ]);
        let input = CritiqueLoopInput::new("Test").with_max_revisions(1);

        match workflow.run(input).await {
            Err(Error::MaxRevisionsExceeded { attempts }) => assert_eq!(attempts, 1),
            other => panic!("expected MaxRevisionsExceeded with 1 attempt, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_critique_loop_zero_max_revisions_rejects_first_revise() {
        // With a zero budget the first revise request must fail without revising.
        let workflow = workflow_with(&[
            "Draft.",
            r#"{"decision": "revise", "critique": "No", "feedback": "Redo"}"#,
        ]);
        let input = CritiqueLoopInput::new("Test").with_max_revisions(0);

        match workflow.run(input).await {
            Err(Error::MaxRevisionsExceeded { attempts }) => assert_eq!(attempts, 0),
            other => panic!("expected MaxRevisionsExceeded with 0 attempts, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_critique_loop_rejects_unknown_decision() {
        let workflow = workflow_with(&[
            "Draft.",
            r#"{"decision": "maybe", "critique": "Unsure"}"#,
        ]);

        assert!(workflow.run(CritiqueLoopInput::new("Test")).await.is_err());
    }

    #[tokio::test]
    async fn test_critique_loop_rejects_non_json_critique() {
        let workflow = workflow_with(&["Draft.", "This looks fine to me."]);

        assert!(workflow.run(CritiqueLoopInput::new("Test")).await.is_err());
    }

    #[tokio::test]
    async fn test_critique_loop_rejects_revise_without_feedback() {
        let workflow = workflow_with(&[
            "Draft.",
            r#"{"decision": "revise", "critique": "Needs work"}"#,
        ]);

        assert!(workflow.run(CritiqueLoopInput::new("Test")).await.is_err());
    }
}