1use ecl_core::llm::{CompletionRequest, LlmProvider, Message};
4use ecl_core::{CritiqueDecision, Error, Result, WorkflowId};
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
/// Default revision budget, used when [`CritiqueLoopInput::max_revisions`] is `None`.
const MAX_REVISIONS: u32 = 3;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct CritiqueLoopInput {
14 pub workflow_id: WorkflowId,
16
17 pub topic: String,
19
20 pub max_revisions: Option<u32>,
22}
23
24impl CritiqueLoopInput {
25 pub fn new(topic: impl Into<String>) -> Self {
27 Self {
28 workflow_id: WorkflowId::new(),
29 topic: topic.into(),
30 max_revisions: None,
31 }
32 }
33
34 pub fn with_max_revisions(mut self, max: u32) -> Self {
36 self.max_revisions = Some(max);
37 self
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct CritiqueLoopOutput {
44 pub workflow_id: WorkflowId,
46
47 pub final_text: String,
49
50 pub revision_count: u32,
52
53 pub critiques: Vec<String>,
55}
56
57#[derive(Clone)]
59pub struct CritiqueLoopWorkflow {
60 llm: Arc<dyn LlmProvider>,
61}
62
63impl CritiqueLoopWorkflow {
64 pub fn new(llm: Arc<dyn LlmProvider>) -> Self {
66 Self { llm }
67 }
68
69 pub async fn run(&self, input: CritiqueLoopInput) -> Result<CritiqueLoopOutput> {
71 let max_revisions = input.max_revisions.unwrap_or(MAX_REVISIONS);
72
73 tracing::info!(
74 workflow_id = %input.workflow_id,
75 topic = %input.topic,
76 max_revisions = max_revisions,
77 "Starting critique-revise workflow"
78 );
79
80 let mut current_draft = self.generate_step(&input.topic).await?;
82
83 let mut revision_count = 0u32;
84 let mut critiques = Vec::new();
85
86 loop {
88 let (critique_text, decision) =
90 self.critique_step(¤t_draft, revision_count).await?;
91
92 critiques.push(critique_text.clone());
93
94 match decision {
95 CritiqueDecision::Pass => {
96 tracing::info!(
97 workflow_id = %input.workflow_id,
98 revision_count,
99 "Critique passed, workflow complete"
100 );
101 break;
102 }
103 CritiqueDecision::Revise { feedback } => {
104 if revision_count >= max_revisions {
105 tracing::warn!(
106 workflow_id = %input.workflow_id,
107 attempts = max_revisions,
108 "Maximum revisions exceeded"
109 );
110 return Err(Error::MaxRevisionsExceeded {
111 attempts: max_revisions,
112 });
113 }
114
115 tracing::info!(
116 workflow_id = %input.workflow_id,
117 revision_count,
118 feedback = %feedback,
119 "Revision requested"
120 );
121
122 current_draft = self
124 .revise_step(¤t_draft, &feedback, revision_count)
125 .await?;
126
127 revision_count += 1;
128 }
129 #[allow(unreachable_patterns)]
131 _ => {
132 return Err(Error::validation("Unknown critique decision variant"));
133 }
134 }
135 }
136
137 tracing::info!(
138 workflow_id = %input.workflow_id,
139 revision_count,
140 "Critique-revise workflow completed"
141 );
142
143 Ok(CritiqueLoopOutput {
144 workflow_id: input.workflow_id,
145 final_text: current_draft,
146 revision_count,
147 critiques,
148 })
149 }
150
151 async fn generate_step(&self, topic: &str) -> Result<String> {
153 tracing::info!(topic = %topic, "Generating initial content");
154
155 let request = CompletionRequest::new(vec![Message::user(format!(
156 "Write a paragraph about: {}",
157 topic
158 ))])
159 .with_system_prompt("You are a content generator. Write clear paragraphs.")
160 .with_max_tokens(500);
161
162 let response = self.llm.complete(request).await?;
163
164 tracing::info!(tokens = response.tokens_used.total(), "Content generated");
165
166 Ok(response.content)
167 }
168
169 async fn critique_step(
171 &self,
172 content: &str,
173 attempt: u32,
174 ) -> Result<(String, CritiqueDecision)> {
175 tracing::info!(attempt, "Critiquing content");
176
177 let request = CompletionRequest::new(vec![Message::user(format!(
178 "Critique this text and decide if it needs revision.\n\
179 Respond with JSON: {{\"decision\": \"pass\" or \"revise\", \"critique\": \"your critique\", \"feedback\": \"what to improve\"}}\n\n\
180 Text:\n{}",
181 content
182 ))])
183 .with_system_prompt("You are a writing critic. Be helpful but thorough.")
184 .with_max_tokens(400);
185
186 let response = self.llm.complete(request).await?;
187
188 let parsed: serde_json::Value = serde_json::from_str(&response.content)
190 .map_err(|e| Error::validation(format!("Failed to parse critique JSON: {}", e)))?;
191
192 let critique = parsed["critique"]
193 .as_str()
194 .ok_or_else(|| Error::validation("Missing critique field"))?
195 .to_string();
196
197 let decision = match parsed["decision"].as_str() {
198 Some("pass") => CritiqueDecision::Pass,
199 Some("revise") => {
200 let feedback = parsed["feedback"]
201 .as_str()
202 .ok_or_else(|| Error::validation("Missing feedback for revise decision"))?
203 .to_string();
204 CritiqueDecision::Revise { feedback }
205 }
206 _ => return Err(Error::validation("Invalid decision value")),
207 };
208
209 tracing::info!(
210 attempt,
211 decision = ?decision,
212 "Critique step completed"
213 );
214
215 Ok((critique, decision))
216 }
217
218 async fn revise_step(&self, original: &str, feedback: &str, attempt: u32) -> Result<String> {
220 tracing::info!(attempt, "Revising content");
221
222 let request = CompletionRequest::new(vec![Message::user(format!(
223 "Revise this text based on the feedback:\n\n\
224 Original:\n{}\n\n\
225 Feedback:\n{}",
226 original, feedback
227 ))])
228 .with_system_prompt("You are a content editor. Improve the text based on feedback.")
229 .with_max_tokens(600);
230
231 let response = self.llm.complete(request).await?;
232
233 tracing::info!(
234 attempt,
235 tokens = response.tokens_used.total(),
236 "Revision completed"
237 );
238
239 Ok(response.content)
240 }
241}
242
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use ecl_core::llm::MockLlmProvider;

    // Pure constructor checks await nothing, so a plain `#[test]` is enough; no
    // Tokio runtime needs to be built for them.
    #[test]
    fn test_critique_loop_input_creation() {
        let input = CritiqueLoopInput::new("Test topic");
        assert_eq!(input.topic, "Test topic");
        assert_eq!(input.max_revisions, None);
    }

    #[test]
    fn test_critique_loop_with_max_revisions() {
        let input = CritiqueLoopInput::new("Test").with_max_revisions(5);
        assert_eq!(input.max_revisions, Some(5));
    }

    #[tokio::test]
    async fn test_critique_loop_pass_immediately() {
        // Responses in call order: generate, critique.
        let mock_llm = Arc::new(MockLlmProvider::new(vec![
            "Generated content.".to_string(),
            r#"{"decision": "pass", "critique": "Looks good!"}"#.to_string(),
        ]));

        let workflow = CritiqueLoopWorkflow::new(mock_llm);
        let input = CritiqueLoopInput::new("Test topic");

        let output = workflow.run(input.clone()).await.unwrap();

        assert_eq!(output.workflow_id, input.workflow_id);
        assert_eq!(output.final_text, "Generated content.");
        assert_eq!(output.revision_count, 0);
        assert_eq!(output.critiques.len(), 1);
    }

    #[tokio::test]
    async fn test_critique_loop_with_one_revision() {
        // Responses in call order: generate, critique, revise, critique.
        let mock_llm = Arc::new(MockLlmProvider::new(vec![
            "Initial draft.".to_string(),
            r#"{"decision": "revise", "critique": "Needs work", "feedback": "Add more detail"}"#
                .to_string(),
            "Improved draft with more detail.".to_string(),
            r#"{"decision": "pass", "critique": "Much better!"}"#.to_string(),
        ]));

        let workflow = CritiqueLoopWorkflow::new(mock_llm);
        let input = CritiqueLoopInput::new("Test topic");

        let output = workflow.run(input).await.unwrap();

        assert_eq!(output.final_text, "Improved draft with more detail.");
        assert_eq!(output.revision_count, 1);
        assert_eq!(output.critiques.len(), 2);
    }

    #[tokio::test]
    async fn test_critique_loop_max_revisions_exceeded() {
        // Default budget: three revisions, then the fourth `revise` verdict fails.
        let mock_llm = Arc::new(MockLlmProvider::new(vec![
            "Draft.".to_string(),
            r#"{"decision": "revise", "critique": "Try again", "feedback": "More work needed"}"#
                .to_string(),
            "Revised 1.".to_string(),
            r#"{"decision": "revise", "critique": "Still not good", "feedback": "Keep trying"}"#
                .to_string(),
            "Revised 2.".to_string(),
            r#"{"decision": "revise", "critique": "Nope", "feedback": "Again"}"#.to_string(),
            "Revised 3.".to_string(),
            r#"{"decision": "revise", "critique": "Still no", "feedback": "More"}"#.to_string(),
        ]));

        let workflow = CritiqueLoopWorkflow::new(mock_llm);
        let input = CritiqueLoopInput::new("Test topic");

        let result = workflow.run(input).await;

        assert!(result.is_err());
        let Error::MaxRevisionsExceeded { attempts } = result.unwrap_err() else {
            unreachable!("Expected MaxRevisionsExceeded error");
        };
        assert_eq!(attempts, MAX_REVISIONS);
    }

    #[tokio::test]
    async fn test_critique_loop_custom_max_revisions() {
        let mock_llm = Arc::new(MockLlmProvider::new(vec![
            "Draft.".to_string(),
            r#"{"decision": "revise", "critique": "Revise", "feedback": "Improve"}"#.to_string(),
            "Revised 1.".to_string(),
            r#"{"decision": "revise", "critique": "Again", "feedback": "More"}"#.to_string(),
        ]));

        let workflow = CritiqueLoopWorkflow::new(mock_llm);
        let input = CritiqueLoopInput::new("Test").with_max_revisions(1);

        let result = workflow.run(input).await;

        assert!(result.is_err());
        let Error::MaxRevisionsExceeded { attempts } = result.unwrap_err() else {
            unreachable!("Expected MaxRevisionsExceeded error with 1 attempt");
        };
        assert_eq!(attempts, 1);
    }

    #[tokio::test]
    async fn test_critique_loop_rejects_non_json_critique() {
        // A critique reply with no JSON object at all must fail the run.
        let mock_llm = Arc::new(MockLlmProvider::new(vec![
            "Draft.".to_string(),
            "This is not JSON.".to_string(),
        ]));

        let workflow = CritiqueLoopWorkflow::new(mock_llm);
        let result = workflow.run(CritiqueLoopInput::new("Test topic")).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_critique_loop_revise_without_feedback_fails() {
        // A `revise` verdict is only actionable with feedback, so its absence is an error.
        let mock_llm = Arc::new(MockLlmProvider::new(vec![
            "Draft.".to_string(),
            r#"{"decision": "revise", "critique": "Needs work"}"#.to_string(),
        ]));

        let workflow = CritiqueLoopWorkflow::new(mock_llm);
        let result = workflow.run(CritiqueLoopInput::new("Test topic")).await;

        assert!(result.is_err());
    }
}