sapiens/
lib.rs

1//! Sapiens library
2//!
3//! *Sapiens uses tools to interact with the world.*
4//!
//! An experiment in handing tools over to the machine.
6//!
7//! # Overview
8//! This library is the core of Sapiens. It contains the logic for the
9//! interaction between the user, the language model and the tools.
10//!
//! # More information
//! See <https://github.com/ssoudan/sapiens/tree/main/sapiens_cli> for an
//! example of usage, or
//! <https://github.com/ssoudan/sapiens/tree/main/sapiens_bot> for a Discord
//! bot.
//!
//! <https://github.com/ssoudan/sapiens/tree/main/sapiens_exp> is a framework
//! to run experiments and collect traces of the interactions between the
//! language model and the tools while they accomplish a task.
//!
//! A collection of tools is defined in
//! <https://github.com/ssoudan/sapiens/tree/main/sapiens_tools>.
20pub mod context;
21
22/// Prompt generation logic
23pub mod prompt;
24
25/// Toolbox for sapiens
26pub mod tools;
27
28/// Language models
29pub mod models;
30
31pub mod chains;
32
33use std::fmt::Debug;
34use std::str::FromStr;
35use std::sync::{Arc, Weak};
36
37use clap::builder::PossibleValue;
38use serde::{Deserialize, Serialize};
39use tokio::sync::Mutex;
40
41use crate::chains::{Chain, Message, MultiStepOODAChain, SingleStepOODAChain};
42use crate::context::{ChatEntry, ContextDump};
43use crate::models::openai::OpenAI;
44use crate::models::{ModelRef, ModelResponse, Role, Usage};
45use crate::tools::invocation::InvocationError;
46use crate::tools::toolbox::{InvokeResult, Toolbox};
47use crate::tools::{TerminationMessage, ToolUseError};
48
49/// The error type for the bot
50#[derive(thiserror::Error, Debug)]
51pub enum Error {
52    /// Failed to add to the chat history
53    #[error("Failed to add to the chat history: {0}")]
54    ChatHistoryError(#[from] context::Error),
55    /// Model evaluation error
56    #[error("Model evaluation error: {0}")]
57    ModelEvaluationError(#[from] models::Error),
58    /// Reached the maximum number of steps
59    #[error("Maximal number of steps reached")]
60    MaxStepsReached,
61    /// The response is too long
62    #[error("The response is too long: {0}")]
63    ActionResponseTooLong(String),
64    /// Error in the chain
65    #[error("Chain error: {0}")]
66    ChainError(#[from] chains::Error),
67}
68
69/// Type of chain to use
70#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71pub enum ChainType {
72    /// OODA single step chain
73    #[default]
74    SingleStepOODA,
75    /// OODA multi step chain
76    MultiStepOODA,
77}
78
79impl FromStr for ChainType {
80    type Err = String;
81
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        match s {
84            "single-step-ooda" => Ok(ChainType::SingleStepOODA),
85            "multi-step-ooda" => Ok(ChainType::MultiStepOODA),
86            _ => Err(format!("Unknown chain type: {}", s)),
87        }
88    }
89}
90
91#[cfg(feature = "clap")]
92impl clap::ValueEnum for ChainType {
93    fn value_variants<'a>() -> &'a [Self] {
94        &[ChainType::SingleStepOODA, ChainType::MultiStepOODA]
95    }
96
97    fn to_possible_value(&self) -> Option<PossibleValue> {
98        match self {
99            ChainType::SingleStepOODA => Some(PossibleValue::new("single-step-ooda")),
100            ChainType::MultiStepOODA => Some(PossibleValue::new("multi-step-ooda")),
101        }
102    }
103}
104
105/// Configuration for the bot
106#[derive(Clone)]
107pub struct SapiensConfig {
108    /// The model to use
109    pub model: ModelRef,
110    /// The maximum number of steps
111    pub max_steps: usize,
112    /// The type of chain to use
113    pub chain_type: ChainType,
114    /// The minimum number of tokens that need to be available for completion
115    pub min_tokens_for_completion: usize,
116    /// Maximum number of tokens for the model to generate
117    pub max_tokens: Option<usize>,
118}
119
120impl Debug for SapiensConfig {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("Config")
123            .field("max_steps", &self.max_steps)
124            .field("chain_type", &self.chain_type)
125            .field("min_tokens_for_completion", &self.min_tokens_for_completion)
126            .field("max_tokens", &self.max_tokens)
127            .finish()
128    }
129}
130
131impl Default for SapiensConfig {
132    fn default() -> Self {
133        Self {
134            model: Arc::new(Box::<OpenAI>::default()),
135            max_steps: 10,
136            chain_type: ChainType::SingleStepOODA,
137            min_tokens_for_completion: 256,
138            max_tokens: None,
139        }
140    }
141}
142
143/// An update from the model
144#[derive(Debug, Clone)]
145pub struct ModelNotification {
146    /// The message from the model
147    pub chat_entry: ChatEntry,
148    /// The number of tokens used by the model
149    pub usage: Option<Usage>,
150}
151
152impl From<ModelResponse> for ModelNotification {
153    fn from(res: ModelResponse) -> Self {
154        Self {
155            chat_entry: ChatEntry {
156                role: Role::Assistant,
157                msg: res.msg,
158            },
159            usage: res.usage,
160        }
161    }
162}
163
164/// A message from a scheduler
165#[derive(Debug, Clone)]
166pub struct MessageNotification {
167    /// The message from the scheduler
168    pub message: Message,
169}
170
171impl From<Message> for MessageNotification {
172    fn from(message: Message) -> Self {
173        Self { message }
174    }
175}
176
177/// Notification of the result of a tool invocation
178pub enum InvocationResultNotification {
179    /// Invocation success notification
180    InvocationSuccess(InvocationSuccessNotification),
181    /// Invocation failure notification
182    InvocationFailure(InvocationFailureNotification),
183    /// Invalid invocation notification
184    InvalidInvocation(InvalidInvocationNotification),
185}
186
187impl From<InvokeResult> for InvocationResultNotification {
188    fn from(res: InvokeResult) -> Self {
189        match res {
190            InvokeResult::NoInvocationsFound { e } => {
191                InvocationResultNotification::InvalidInvocation(InvalidInvocationNotification {
192                    e,
193                    invocation_count: 0,
194                })
195            }
196            InvokeResult::NoValidInvocationsFound {
197                e,
198                invocation_count,
199            } => InvocationResultNotification::InvalidInvocation(InvalidInvocationNotification {
200                e,
201                invocation_count,
202            }),
203            InvokeResult::Success {
204                invocation_count,
205                tool_name,
206                extracted_input,
207                result,
208            } => InvocationResultNotification::InvocationSuccess(InvocationSuccessNotification {
209                invocation_count,
210                tool_name,
211                extracted_input,
212                result,
213            }),
214            InvokeResult::Error {
215                invocation_count,
216                tool_name,
217                extracted_input,
218                e,
219            } => InvocationResultNotification::InvocationFailure(InvocationFailureNotification {
220                invocation_count,
221                tool_name,
222                extracted_input,
223                e,
224            }),
225        }
226    }
227}
228
229/// Invocation success notification
230pub struct InvocationSuccessNotification {
231    /// The number of invocation blocks in the message
232    pub invocation_count: usize,
233    /// The tool name
234    pub tool_name: String,
235    /// The input that was extracted from the message and passed to `tool_name`
236    pub extracted_input: String,
237    /// The result
238    pub result: String,
239}
240
241/// Invocation failure notification
242pub struct InvocationFailureNotification {
243    /// Number of invocation  blocks in the message
244    pub invocation_count: usize,
245    /// The tool name
246    pub tool_name: String,
247    /// The input that was extracted from the message and passed to `tool_name`
248    pub extracted_input: String,
249    /// The result
250    pub e: ToolUseError,
251}
252
253/// Invalid invocation notification
254pub struct InvalidInvocationNotification {
255    /// The result
256    pub e: InvocationError,
257    /// Number of invocation blocks in the message
258    pub invocation_count: usize,
259}
260
261/// Termination notification
262pub struct TerminationNotification {
263    /// The messages
264    pub messages: Vec<TerminationMessage>,
265}
266
267/// Observer for the step progresses
268#[async_trait::async_trait]
269pub trait RuntimeObserver: Send {
270    /// Called when the task is submitted
271    async fn on_task(&mut self, _task: &str) {}
272
273    /// Called on start
274    async fn on_start(&mut self, _context: ContextDump) {}
275
276    /// Called when the model returns something
277    async fn on_model_update(&mut self, _event: ModelNotification) {}
278
279    /// Called when the scheduler has selected a message
280    async fn on_message(&mut self, _event: MessageNotification) {}
281
282    /// Called when the tool invocation was successful
283    async fn on_invocation_result(&mut self, _event: InvocationResultNotification) {}
284
285    /// Called when the task is done
286    async fn on_termination(&mut self, _event: TerminationNotification) {}
287}
288
289/// Wrap an observer into the a [`StrongRuntimeObserver<O>`] = [`Arc<Mutex<O>>`]
290///
291/// Use [`Arc::downgrade`] to get a [`Weak<Mutex<dyn RuntimeObserver>>`] and
292/// pass it to [`run_to_the_end`] for example.
293pub fn wrap_observer<O: RuntimeObserver + 'static>(observer: O) -> StrongRuntimeObserver<O> {
294    Arc::new(Mutex::new(observer))
295}
296
297/// A strong reference to the observer
298pub type StrongRuntimeObserver<O> = Arc<Mutex<O>>;
299
300/// A weak reference to the observer
301pub type WeakRuntimeObserver = Weak<Mutex<dyn RuntimeObserver>>;
302
303/// A void observer
304pub struct VoidTaskProgressUpdateObserver;
305
306#[cfg(test)]
307pub(crate) fn void_observer() -> StrongRuntimeObserver<VoidTaskProgressUpdateObserver> {
308    wrap_observer(VoidTaskProgressUpdateObserver)
309}
310
311#[async_trait::async_trait]
312impl RuntimeObserver for VoidTaskProgressUpdateObserver {}
313
314/// A step in the task
315pub struct Step {
316    task_chain: Box<dyn Chain>,
317    observer: WeakRuntimeObserver,
318}
319
320impl Step {
321    /// Run the task for a single step
322    async fn step(mut self) -> Result<TaskState, Error> {
323        let termination_messages = self.task_chain.step().await?;
324
325        // check if the task is done
326        if !termination_messages.is_empty() {
327            if let Some(observer) = self.observer.upgrade() {
328                observer
329                    .lock()
330                    .await
331                    .on_termination(TerminationNotification {
332                        messages: termination_messages.clone(),
333                    })
334                    .await;
335            }
336
337            return Ok(TaskState::Stop {
338                stop: Stop {
339                    termination_messages,
340                },
341            });
342        }
343
344        Ok(TaskState::Step { step: self })
345    }
346}
347
348/// The task is done
349pub struct Stop {
350    /// The termination messages
351    pub termination_messages: Vec<TerminationMessage>,
352}
353
354/// The state machine of a task
355pub enum TaskState {
356    /// The task is not done yet
357    Step {
358        /// The actual step task
359        step: Step,
360    },
361    /// The task is done
362    Stop {
363        /// the actual stopped task
364        stop: Stop,
365    },
366}
367
368impl TaskState {
369    /// Create a new [`TaskState`] for a `task`.
370    pub async fn new(config: SapiensConfig, toolbox: Toolbox, task: String) -> Result<Self, Error> {
371        let observer = wrap_observer(VoidTaskProgressUpdateObserver {});
372        let observer = Arc::downgrade(&observer);
373
374        TaskState::with_observer(config, toolbox, task, observer).await
375    }
376
377    /// Create a new [`TaskState`] for a `task`.
378    ///
379    /// The `observer` will be called when the task starts and when a step is
380    /// completed - either successfully or not. The `observer` will be called
381    /// with the latest chat history element. It is also called on error.
382    pub async fn with_observer(
383        config: SapiensConfig,
384        toolbox: Toolbox,
385        task: String,
386        observer: WeakRuntimeObserver,
387    ) -> Result<Self, Error> {
388        if let Some(observer) = observer.upgrade() {
389            observer.lock().await.on_task(&task).await;
390        }
391
392        let task_chain = match config.chain_type {
393            ChainType::SingleStepOODA => {
394                let chain = SingleStepOODAChain::new(config, toolbox, observer.clone())
395                    .await?
396                    .with_task(task);
397                Box::new(chain) as Box<dyn Chain>
398            }
399            ChainType::MultiStepOODA => {
400                let chain = MultiStepOODAChain::new(config, toolbox, observer.clone())
401                    .await?
402                    .with_task(task);
403                Box::new(chain) as Box<dyn Chain>
404            }
405        };
406
407        // call the observer
408        if let Some(observer) = observer.upgrade() {
409            observer.lock().await.on_start(task_chain.dump()).await;
410        }
411
412        Ok(TaskState::Step {
413            step: Step {
414                task_chain,
415                observer,
416            },
417        })
418    }
419
420    /// Run the task until it is done
421    pub async fn run(mut self) -> Result<Stop, Error> {
422        loop {
423            match self {
424                TaskState::Step { step } => {
425                    self = step.step().await?;
426                }
427                TaskState::Stop { stop } => {
428                    return Ok(stop);
429                }
430            }
431        }
432    }
433
434    /// Run the task for a single step
435    pub async fn step(self) -> Result<Self, Error> {
436        match self {
437            TaskState::Step { step } => step.step().await,
438            TaskState::Stop { stop } => Ok(TaskState::Stop { stop }),
439        }
440    }
441
442    /// is the task done?
443    pub fn is_done(&self) -> Option<Vec<TerminationMessage>> {
444        match self {
445            TaskState::Step { step: _ } => None,
446            TaskState::Stop { stop } => Some(stop.termination_messages.clone()),
447        }
448    }
449}
450
451/// Run until the task is done or the maximum number of steps is reached
452///
453/// See [`TaskState::new`], [`TaskState::step`] and [`TaskState::run`] for
454/// more flexible ways to run a task
455#[tracing::instrument(skip(toolbox, observer, config))]
456pub async fn run_to_the_end(
457    config: SapiensConfig,
458    toolbox: Toolbox,
459    task: String,
460    observer: WeakRuntimeObserver,
461) -> Result<Vec<TerminationMessage>, Error> {
462    let task_state = TaskState::with_observer(config, toolbox, task, observer).await?;
463
464    let stop = task_state.run().await?;
465
466    Ok(stop.termination_messages)
467}