use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::{Message, Result};
use cognis_llm::chat::ChatResponse;
use super::{Middleware, MiddlewareCtx, Next};
/// Sentinel embedded in the injected system message so the middleware can
/// recognise its own prompt in a conversation and avoid inserting it twice.
const MARKER: &str = "<!-- cognis:planning-mw -->";

/// Instruction text used by `Planning::new` / `Planning::default` when no
/// custom text is supplied via `Planning::with_instruction`.
///
/// The trailing `\` escapes join the lines into a single-line string.
const DEFAULT_INSTRUCTION: &str =
    "Before taking any action, write out a short numbered plan of the steps \
     you'll take. Then execute each step. After completing all steps, \
     summarize the result.";

/// Middleware that asks the model to write a plan before acting.
///
/// On each call it prepends a system message consisting of a marker line
/// followed by the configured instruction, unless a system message that
/// already contains the marker is present in the conversation.
#[derive(Debug, Clone)]
pub struct Planning {
    /// Text placed after the marker line in the injected system message.
    instruction: String,
}

impl Default for Planning {
    /// Same as [`Planning::new`]: uses the built-in planning instruction.
    fn default() -> Self {
        Self::new()
    }
}

impl Planning {
    /// Creates the middleware with the built-in default planning instruction.
    #[must_use]
    pub fn new() -> Self {
        Self {
            instruction: DEFAULT_INSTRUCTION.to_string(),
        }
    }

    /// Replaces the instruction text, consuming and returning the builder.
    ///
    /// The marker line is still prepended when the message is injected, so
    /// `s` should contain only the instruction itself.
    #[must_use]
    pub fn with_instruction(mut self, s: impl Into<String>) -> Self {
        self.instruction = s.into();
        self
    }
}
#[async_trait]
impl Middleware for Planning {
    /// Ensures the planning prompt is present, then hands off to `next`.
    ///
    /// If no system message in `ctx.messages` contains `MARKER`, a new system
    /// message of the form `"{MARKER}\n{instruction}"` is inserted at index 0.
    /// Otherwise the messages are left unchanged. The result of the
    /// downstream `next.invoke` is returned as-is.
    async fn call(&self, mut ctx: MiddlewareCtx, next: Arc<dyn Next>) -> Result<ChatResponse> {
        // True only when no system message carries the marker.
        let needs_prompt = ctx.messages.iter().all(|msg| match msg {
            Message::System(sys) => !sys.content.contains(MARKER),
            _ => true,
        });

        if needs_prompt {
            ctx.messages
                .insert(0, Message::system(format!("{MARKER}\n{}", self.instruction)));
        }

        next.invoke(ctx).await
    }

    /// Identifier reported for this middleware.
    fn name(&self) -> &str {
        "Planning"
    }
}
#[cfg(test)]
mod tests {
    use super::super::tests_util::*;
    use super::*;
    use crate::middleware::MiddlewarePipeline;
    use cognis_llm::chat::ChatOptions;
    use cognis_llm::Client;

    /// A conversation without a planning prompt reaches the provider with the
    /// marked prompt as its first message.
    #[tokio::test]
    async fn injects_planning_instruction() {
        let recorder = make_recording_provider("ok");
        let pipeline = MiddlewarePipeline::new()
            .push(Planning::new())
            .build(Client::new(recorder.clone()));

        let conversation = vec![Message::human("solve x")];
        let _ = pipeline
            .invoke(conversation, Vec::new(), ChatOptions::default())
            .await
            .unwrap();

        let calls = recorder.received.lock().unwrap();
        let first_request = &calls[0].0;
        assert!(first_request[0].content().contains(MARKER));
    }

    /// When the conversation already carries the marker, no second planning
    /// prompt is inserted, so exactly one marked message reaches the provider.
    #[tokio::test]
    async fn idempotent_no_double_insert() {
        let recorder = make_recording_provider("ok");
        let pipeline = MiddlewarePipeline::new()
            .push(Planning::new())
            .build(Client::new(recorder.clone()));

        let conversation = vec![
            Message::system(format!("{MARKER}\nold")),
            Message::human("hi"),
        ];
        let _ = pipeline
            .invoke(conversation, Vec::new(), ChatOptions::default())
            .await
            .unwrap();

        let calls = recorder.received.lock().unwrap();
        let marked = calls[0]
            .0
            .iter()
            .filter(|msg| msg.content().contains(MARKER))
            .count();
        assert_eq!(marked, 1);
    }
}