1pub mod context;
21
22pub mod prompt;
24
25pub mod tools;
27
28pub mod models;
30
31pub mod chains;
32
33use std::fmt::Debug;
34use std::str::FromStr;
35use std::sync::{Arc, Weak};
36
37use clap::builder::PossibleValue;
38use serde::{Deserialize, Serialize};
39use tokio::sync::Mutex;
40
41use crate::chains::{Chain, Message, MultiStepOODAChain, SingleStepOODAChain};
42use crate::context::{ChatEntry, ContextDump};
43use crate::models::openai::OpenAI;
44use crate::models::{ModelRef, ModelResponse, Role, Usage};
45use crate::tools::invocation::InvocationError;
46use crate::tools::toolbox::{InvokeResult, Toolbox};
47use crate::tools::{TerminationMessage, ToolUseError};
48
49#[derive(thiserror::Error, Debug)]
51pub enum Error {
52 #[error("Failed to add to the chat history: {0}")]
54 ChatHistoryError(#[from] context::Error),
55 #[error("Model evaluation error: {0}")]
57 ModelEvaluationError(#[from] models::Error),
58 #[error("Maximal number of steps reached")]
60 MaxStepsReached,
61 #[error("The response is too long: {0}")]
63 ActionResponseTooLong(String),
64 #[error("Chain error: {0}")]
66 ChainError(#[from] chains::Error),
67}
68
69#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71pub enum ChainType {
72 #[default]
74 SingleStepOODA,
75 MultiStepOODA,
77}
78
79impl FromStr for ChainType {
80 type Err = String;
81
82 fn from_str(s: &str) -> Result<Self, Self::Err> {
83 match s {
84 "single-step-ooda" => Ok(ChainType::SingleStepOODA),
85 "multi-step-ooda" => Ok(ChainType::MultiStepOODA),
86 _ => Err(format!("Unknown chain type: {}", s)),
87 }
88 }
89}
90
/// Exposes [`ChainType`] as a `clap` value enum when the `clap` feature is on.
#[cfg(feature = "clap")]
impl clap::ValueEnum for ChainType {
    /// Every variant, in declaration order.
    fn value_variants<'a>() -> &'a [Self] {
        &[Self::SingleStepOODA, Self::MultiStepOODA]
    }

    /// The command-line spelling of the variant; the names match the ones
    /// accepted by the `FromStr` implementation.
    fn to_possible_value(&self) -> Option<PossibleValue> {
        let name = match self {
            Self::SingleStepOODA => "single-step-ooda",
            Self::MultiStepOODA => "multi-step-ooda",
        };

        Some(PossibleValue::new(name))
    }
}
104
105#[derive(Clone)]
107pub struct SapiensConfig {
108 pub model: ModelRef,
110 pub max_steps: usize,
112 pub chain_type: ChainType,
114 pub min_tokens_for_completion: usize,
116 pub max_tokens: Option<usize>,
118}
119
120impl Debug for SapiensConfig {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("Config")
123 .field("max_steps", &self.max_steps)
124 .field("chain_type", &self.chain_type)
125 .field("min_tokens_for_completion", &self.min_tokens_for_completion)
126 .field("max_tokens", &self.max_tokens)
127 .finish()
128 }
129}
130
131impl Default for SapiensConfig {
132 fn default() -> Self {
133 Self {
134 model: Arc::new(Box::<OpenAI>::default()),
135 max_steps: 10,
136 chain_type: ChainType::SingleStepOODA,
137 min_tokens_for_completion: 256,
138 max_tokens: None,
139 }
140 }
141}
142
143#[derive(Debug, Clone)]
145pub struct ModelNotification {
146 pub chat_entry: ChatEntry,
148 pub usage: Option<Usage>,
150}
151
152impl From<ModelResponse> for ModelNotification {
153 fn from(res: ModelResponse) -> Self {
154 Self {
155 chat_entry: ChatEntry {
156 role: Role::Assistant,
157 msg: res.msg,
158 },
159 usage: res.usage,
160 }
161 }
162}
163
164#[derive(Debug, Clone)]
166pub struct MessageNotification {
167 pub message: Message,
169}
170
171impl From<Message> for MessageNotification {
172 fn from(message: Message) -> Self {
173 Self { message }
174 }
175}
176
177pub enum InvocationResultNotification {
179 InvocationSuccess(InvocationSuccessNotification),
181 InvocationFailure(InvocationFailureNotification),
183 InvalidInvocation(InvalidInvocationNotification),
185}
186
187impl From<InvokeResult> for InvocationResultNotification {
188 fn from(res: InvokeResult) -> Self {
189 match res {
190 InvokeResult::NoInvocationsFound { e } => {
191 InvocationResultNotification::InvalidInvocation(InvalidInvocationNotification {
192 e,
193 invocation_count: 0,
194 })
195 }
196 InvokeResult::NoValidInvocationsFound {
197 e,
198 invocation_count,
199 } => InvocationResultNotification::InvalidInvocation(InvalidInvocationNotification {
200 e,
201 invocation_count,
202 }),
203 InvokeResult::Success {
204 invocation_count,
205 tool_name,
206 extracted_input,
207 result,
208 } => InvocationResultNotification::InvocationSuccess(InvocationSuccessNotification {
209 invocation_count,
210 tool_name,
211 extracted_input,
212 result,
213 }),
214 InvokeResult::Error {
215 invocation_count,
216 tool_name,
217 extracted_input,
218 e,
219 } => InvocationResultNotification::InvocationFailure(InvocationFailureNotification {
220 invocation_count,
221 tool_name,
222 extracted_input,
223 e,
224 }),
225 }
226 }
227}
228
/// Details of a tool invocation that produced a result.
///
/// Derives `Debug` and `Clone` so observers can log or retain the
/// notification, matching `ModelNotification` and `MessageNotification`.
#[derive(Debug, Clone)]
pub struct InvocationSuccessNotification {
    /// Invocation count, copied from `InvokeResult::Success`.
    pub invocation_count: usize,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// The input extracted for the tool.
    pub extracted_input: String,
    /// The result returned by the tool.
    pub result: String,
}
240
241pub struct InvocationFailureNotification {
243 pub invocation_count: usize,
245 pub tool_name: String,
247 pub extracted_input: String,
249 pub e: ToolUseError,
251}
252
253pub struct InvalidInvocationNotification {
255 pub e: InvocationError,
257 pub invocation_count: usize,
259}
260
261pub struct TerminationNotification {
263 pub messages: Vec<TerminationMessage>,
265}
266
267#[async_trait::async_trait]
269pub trait RuntimeObserver: Send {
270 async fn on_task(&mut self, _task: &str) {}
272
273 async fn on_start(&mut self, _context: ContextDump) {}
275
276 async fn on_model_update(&mut self, _event: ModelNotification) {}
278
279 async fn on_message(&mut self, _event: MessageNotification) {}
281
282 async fn on_invocation_result(&mut self, _event: InvocationResultNotification) {}
284
285 async fn on_termination(&mut self, _event: TerminationNotification) {}
287}
288
289pub fn wrap_observer<O: RuntimeObserver + 'static>(observer: O) -> StrongRuntimeObserver<O> {
294 Arc::new(Mutex::new(observer))
295}
296
297pub type StrongRuntimeObserver<O> = Arc<Mutex<O>>;
299
300pub type WeakRuntimeObserver = Weak<Mutex<dyn RuntimeObserver>>;
302
/// Observer that ignores every event; it relies entirely on the default
/// (empty) hooks of `RuntimeObserver`.
///
/// Derives the common traits so the public type can be logged, defaulted and
/// copied freely.
#[derive(Debug, Default, Clone, Copy)]
pub struct VoidTaskProgressUpdateObserver;
305
/// Test helper: a strong handle to a no-op observer.
#[cfg(test)]
pub(crate) fn void_observer() -> StrongRuntimeObserver<VoidTaskProgressUpdateObserver> {
    Arc::new(Mutex::new(VoidTaskProgressUpdateObserver))
}
310
311#[async_trait::async_trait]
312impl RuntimeObserver for VoidTaskProgressUpdateObserver {}
313
314pub struct Step {
316 task_chain: Box<dyn Chain>,
317 observer: WeakRuntimeObserver,
318}
319
320impl Step {
321 async fn step(mut self) -> Result<TaskState, Error> {
323 let termination_messages = self.task_chain.step().await?;
324
325 if !termination_messages.is_empty() {
327 if let Some(observer) = self.observer.upgrade() {
328 observer
329 .lock()
330 .await
331 .on_termination(TerminationNotification {
332 messages: termination_messages.clone(),
333 })
334 .await;
335 }
336
337 return Ok(TaskState::Stop {
338 stop: Stop {
339 termination_messages,
340 },
341 });
342 }
343
344 Ok(TaskState::Step { step: self })
345 }
346}
347
348pub struct Stop {
350 pub termination_messages: Vec<TerminationMessage>,
352}
353
354pub enum TaskState {
356 Step {
358 step: Step,
360 },
361 Stop {
363 stop: Stop,
365 },
366}
367
368impl TaskState {
369 pub async fn new(config: SapiensConfig, toolbox: Toolbox, task: String) -> Result<Self, Error> {
371 let observer = wrap_observer(VoidTaskProgressUpdateObserver {});
372 let observer = Arc::downgrade(&observer);
373
374 TaskState::with_observer(config, toolbox, task, observer).await
375 }
376
377 pub async fn with_observer(
383 config: SapiensConfig,
384 toolbox: Toolbox,
385 task: String,
386 observer: WeakRuntimeObserver,
387 ) -> Result<Self, Error> {
388 if let Some(observer) = observer.upgrade() {
389 observer.lock().await.on_task(&task).await;
390 }
391
392 let task_chain = match config.chain_type {
393 ChainType::SingleStepOODA => {
394 let chain = SingleStepOODAChain::new(config, toolbox, observer.clone())
395 .await?
396 .with_task(task);
397 Box::new(chain) as Box<dyn Chain>
398 }
399 ChainType::MultiStepOODA => {
400 let chain = MultiStepOODAChain::new(config, toolbox, observer.clone())
401 .await?
402 .with_task(task);
403 Box::new(chain) as Box<dyn Chain>
404 }
405 };
406
407 if let Some(observer) = observer.upgrade() {
409 observer.lock().await.on_start(task_chain.dump()).await;
410 }
411
412 Ok(TaskState::Step {
413 step: Step {
414 task_chain,
415 observer,
416 },
417 })
418 }
419
420 pub async fn run(mut self) -> Result<Stop, Error> {
422 loop {
423 match self {
424 TaskState::Step { step } => {
425 self = step.step().await?;
426 }
427 TaskState::Stop { stop } => {
428 return Ok(stop);
429 }
430 }
431 }
432 }
433
434 pub async fn step(self) -> Result<Self, Error> {
436 match self {
437 TaskState::Step { step } => step.step().await,
438 TaskState::Stop { stop } => Ok(TaskState::Stop { stop }),
439 }
440 }
441
442 pub fn is_done(&self) -> Option<Vec<TerminationMessage>> {
444 match self {
445 TaskState::Step { step: _ } => None,
446 TaskState::Stop { stop } => Some(stop.termination_messages.clone()),
447 }
448 }
449}
450
451#[tracing::instrument(skip(toolbox, observer, config))]
456pub async fn run_to_the_end(
457 config: SapiensConfig,
458 toolbox: Toolbox,
459 task: String,
460 observer: WeakRuntimeObserver,
461) -> Result<Vec<TerminationMessage>, Error> {
462 let task_state = TaskState::with_observer(config, toolbox, task, observer).await?;
463
464 let stop = task_state.run().await?;
465
466 Ok(stop.termination_messages)
467}