Skip to main content

neuron_types/
traits.rs

//! Core traits: Provider, EmbeddingProvider, Tool, ToolDyn, ContextStrategy, ObservabilityHook, DurableContext, PermissionPolicy.
2
3use std::future::Future;
4use std::time::Duration;
5
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8
9use crate::error::{
10    ContextError, DurableError, EmbeddingError, HookError, ProviderError, ToolError,
11};
12use crate::stream::StreamHandle;
13use crate::types::{
14    CompletionRequest, CompletionResponse, ContentItem, EmbeddingRequest, EmbeddingResponse,
15    Message, ToolContext, ToolDefinition, ToolOutput,
16};
17use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
18
19/// LLM provider trait. Implement this for each provider (Anthropic, OpenAI, Ollama, etc.).
20///
21/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
22/// Not object-safe by design; use generics `<P: Provider>` to compose.
23///
24/// # Example
25///
26/// ```no_run
27/// use std::future::Future;
28/// use neuron_types::*;
29///
30/// struct MyProvider;
31///
32/// impl Provider for MyProvider {
33///     fn complete(&self, request: CompletionRequest)
34///         -> impl Future<Output = Result<CompletionResponse, ProviderError>> + Send
35///     {
36///         async { todo!() }
37///     }
38///
39///     fn complete_stream(&self, request: CompletionRequest)
40///         -> impl Future<Output = Result<StreamHandle, ProviderError>> + Send
41///     {
42///         async { todo!() }
43///     }
44/// }
45/// ```
46pub trait Provider: WasmCompatSend + WasmCompatSync {
47    /// Send a completion request and get a full response.
48    fn complete(
49        &self,
50        request: CompletionRequest,
51    ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;
52
53    /// Send a completion request and get a stream of events.
54    fn complete_stream(
55        &self,
56        request: CompletionRequest,
57    ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
58}
59
60/// Embedding provider trait. Implement this for providers that support text embeddings.
61///
62/// Kept separate from [`Provider`] because not all embedding models support chat completion
63/// and not all chat providers support embeddings. Implement both traits on a struct when a
64/// provider supports both capabilities.
65///
66/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
67/// Not object-safe by design; use generics `<E: EmbeddingProvider>` to compose.
68///
69/// # Example
70///
71/// ```no_run
72/// use std::future::Future;
73/// use neuron_types::*;
74///
75/// struct MyEmbeddingProvider;
76///
77/// impl EmbeddingProvider for MyEmbeddingProvider {
78///     fn embed(&self, request: EmbeddingRequest)
79///         -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + Send
80///     {
81///         async { todo!() }
82///     }
83/// }
84/// ```
85pub trait EmbeddingProvider: WasmCompatSend + WasmCompatSync {
86    /// Send an embedding request and get vectors back.
87    ///
88    /// Multiple input strings are batched into a single request. The returned
89    /// `embeddings` vec is in the same order as `request.input`.
90    fn embed(
91        &self,
92        request: EmbeddingRequest,
93    ) -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend;
94}
95
96/// Strongly-typed tool trait. Implement this for your tools.
97///
98/// The blanket impl of [`ToolDyn`] handles JSON deserialization/serialization
99/// so you work with concrete Rust types.
100///
101/// # Example
102///
103/// ```no_run
104/// use neuron_types::*;
105/// use serde::Deserialize;
106///
107/// #[derive(Debug, Deserialize, schemars::JsonSchema)]
108/// struct MyArgs { query: String }
109///
110/// struct MyTool;
111/// impl Tool for MyTool {
112///     const NAME: &'static str = "my_tool";
113///     type Args = MyArgs;
114///     type Output = String;
115///     type Error = std::io::Error;
116///
117///     fn definition(&self) -> ToolDefinition { todo!() }
118///     fn call(&self, args: MyArgs, ctx: &ToolContext)
119///         -> impl Future<Output = Result<String, std::io::Error>> + Send
120///     { async { Ok(args.query) } }
121/// }
122/// ```
123pub trait Tool: WasmCompatSend + WasmCompatSync {
124    /// The unique name of this tool.
125    const NAME: &'static str;
126    /// The deserialized input type.
127    type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;
128    /// The serializable output type.
129    type Output: Serialize;
130    /// The tool-specific error type.
131    type Error: std::error::Error + WasmCompatSend + 'static;
132
133    /// Returns the tool definition (name, description, schema).
134    fn definition(&self) -> ToolDefinition;
135
136    /// Execute the tool with typed arguments.
137    fn call(
138        &self,
139        args: Self::Args,
140        ctx: &ToolContext,
141    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
142}
143
144/// Type-erased tool for dynamic dispatch. Blanket-implemented for all [`Tool`] impls.
145///
146/// This enables heterogeneous tool collections (`HashMap<String, Arc<dyn ToolDyn>>`)
147/// while preserving type safety at the implementation level.
148pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
149    /// The tool's unique name.
150    fn name(&self) -> &str;
151    /// The tool definition (name, description, input schema).
152    fn definition(&self) -> ToolDefinition;
153    /// Execute the tool with a JSON value input, returning a generic output.
154    fn call_dyn<'a>(
155        &'a self,
156        input: serde_json::Value,
157        ctx: &'a ToolContext,
158    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
159}
160
161/// Blanket implementation: any `Tool` automatically becomes a `ToolDyn`.
162///
163/// Handles:
164/// - Deserializing `serde_json::Value` into `T::Args`
165/// - Calling `T::call(args, ctx)`
166/// - Serializing `T::Output` into `ToolOutput`
167/// - Mapping `T::Error` into `ToolError::ExecutionFailed`
168impl<T: Tool> ToolDyn for T {
169    fn name(&self) -> &str {
170        T::NAME
171    }
172
173    fn definition(&self) -> ToolDefinition {
174        Tool::definition(self)
175    }
176
177    fn call_dyn<'a>(
178        &'a self,
179        input: serde_json::Value,
180        ctx: &'a ToolContext,
181    ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
182        Box::pin(async move {
183            let args: T::Args = serde_json::from_value(input)
184                .map_err(|e| ToolError::InvalidInput(e.to_string()))?;
185
186            let output = self
187                .call(args, ctx)
188                .await
189                .map_err(|e| ToolError::ExecutionFailed(e.to_string().into()))?;
190
191            let structured = serde_json::to_value(&output)
192                .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?;
193
194            let text = match &structured {
195                serde_json::Value::String(s) => s.clone(),
196                other => other.to_string(),
197            };
198
199            Ok(ToolOutput {
200                content: vec![ContentItem::Text(text)],
201                structured_content: Some(structured),
202                is_error: false,
203            })
204        })
205    }
206}
207
208// --- Context Strategy ---
209
210/// Strategy for compacting conversation context when it exceeds token limits.
211///
212/// Implementations decide when to compact and how to reduce the message list.
213///
214/// # Example
215///
216/// ```no_run
217/// use std::future::Future;
218/// use neuron_types::*;
219///
220/// struct KeepLastN { n: usize }
221/// impl ContextStrategy for KeepLastN {
222///     fn should_compact(&self, _messages: &[Message], token_count: usize) -> bool {
223///         token_count > 100_000
224///     }
225///     fn compact(&self, messages: Vec<Message>)
226///         -> impl Future<Output = Result<Vec<Message>, ContextError>> + Send
227///     {
228///         async move { Ok(messages.into_iter().rev().take(self.n).collect()) }
229///     }
230///     fn token_estimate(&self, messages: &[Message]) -> usize { messages.len() * 100 }
231/// }
232/// ```
233pub trait ContextStrategy: WasmCompatSend + WasmCompatSync {
234    /// Whether compaction should be triggered given the current messages and token count.
235    fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;
236
237    /// Compact the message list to reduce token usage.
238    fn compact(
239        &self,
240        messages: Vec<Message>,
241    ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;
242
243    /// Estimate the token count for a list of messages.
244    fn token_estimate(&self, messages: &[Message]) -> usize;
245}
246
247// --- Observability Hooks ---
248
249/// Events fired during the agentic loop for observability.
250#[derive(Debug)]
251pub enum HookEvent<'a> {
252    /// Start of a loop iteration.
253    LoopIteration {
254        /// The current turn number (0-indexed).
255        turn: usize,
256    },
257    /// Before calling the LLM provider.
258    PreLlmCall {
259        /// The request about to be sent.
260        request: &'a CompletionRequest,
261    },
262    /// After receiving the LLM response.
263    PostLlmCall {
264        /// The response received.
265        response: &'a CompletionResponse,
266    },
267    /// Before executing a tool.
268    PreToolExecution {
269        /// Name of the tool.
270        tool_name: &'a str,
271        /// Input arguments.
272        input: &'a serde_json::Value,
273    },
274    /// After executing a tool.
275    PostToolExecution {
276        /// Name of the tool.
277        tool_name: &'a str,
278        /// The tool's output.
279        output: &'a ToolOutput,
280    },
281    /// Context was compacted.
282    ContextCompaction {
283        /// Token count before compaction.
284        old_tokens: usize,
285        /// Token count after compaction.
286        new_tokens: usize,
287    },
288    /// A session started.
289    SessionStart {
290        /// The session identifier.
291        session_id: &'a str,
292    },
293    /// A session ended.
294    SessionEnd {
295        /// The session identifier.
296        session_id: &'a str,
297    },
298}
299
/// Action to take after processing a hook event.
///
/// Derives `Clone`, `PartialEq` and `Eq`, so an action can be copied and
/// compared directly (for example `action == HookAction::Continue`) instead of
/// only pattern matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue normal execution.
    Continue,
    /// Skip the current operation and return a rejection message.
    Skip {
        /// Reason for skipping.
        reason: String,
    },
    /// Terminate the loop immediately.
    Terminate {
        /// Reason for termination.
        reason: String,
    },
}
316
317/// Hook for observability (logging, metrics, telemetry).
318///
319/// Does NOT control execution flow beyond Continue/Skip/Terminate.
320/// For durable execution wrapping, use [`DurableContext`] instead.
321///
322/// # Example
323///
324/// ```no_run
325/// use std::future::Future;
326/// use neuron_types::*;
327///
328/// struct LogHook;
329/// impl ObservabilityHook for LogHook {
330///     fn on_event(&self, event: HookEvent<'_>)
331///         -> impl Future<Output = Result<HookAction, HookError>> + Send
332///     {
333///         async move { println!("{event:?}"); Ok(HookAction::Continue) }
334///     }
335/// }
336/// ```
337pub trait ObservabilityHook: WasmCompatSend + WasmCompatSync {
338    /// Called for each event in the agentic loop.
339    fn on_event(
340        &self,
341        event: HookEvent<'_>,
342    ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
343}
344
345// --- Durable Context ---
346
/// Per-activity settings used when a [`DurableContext`] runs an LLM call or a
/// tool call as a durable activity.
#[derive(Clone, Debug)]
pub struct ActivityOptions {
    /// Upper bound on how long the activity may take to complete.
    pub start_to_close_timeout: Duration,
    /// Heartbeat setting for long-running activities; `None` leaves it unset.
    pub heartbeat_timeout: Option<Duration>,
    /// How failed attempts are retried; `None` leaves it unset.
    pub retry_policy: Option<RetryPolicy>,
}

/// Retry behaviour for durable activities (exponential backoff).
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Factor applied to the delay between successive retries.
    pub backoff_coefficient: f64,
    /// Maximum number of retry attempts.
    pub maximum_attempts: u32,
    /// Cap on the delay between retries.
    pub maximum_interval: Duration,
    /// Names of error types that must not be retried.
    pub non_retryable_errors: Vec<String>,
}
372
373/// Wraps side effects for durable execution engines (Temporal, Restate, Inngest).
374///
375/// When present, the agentic loop calls through this instead of directly
376/// calling provider/tools, enabling journaling, replay, and crash recovery.
377pub trait DurableContext: WasmCompatSend + WasmCompatSync {
378    /// Execute an LLM call as a durable activity.
379    fn execute_llm_call(
380        &self,
381        request: CompletionRequest,
382        options: ActivityOptions,
383    ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;
384
385    /// Execute a tool call as a durable activity.
386    fn execute_tool(
387        &self,
388        tool_name: &str,
389        input: serde_json::Value,
390        ctx: &ToolContext,
391        options: ActivityOptions,
392    ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;
393
394    /// Wait for an external signal with a timeout.
395    fn wait_for_signal<T: DeserializeOwned + WasmCompatSend>(
396        &self,
397        signal_name: &str,
398        timeout: Duration,
399    ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend;
400
401    /// Whether the workflow should continue-as-new to avoid history bloat.
402    fn should_continue_as_new(&self) -> bool;
403
404    /// Continue the workflow as a new execution with the given state.
405    fn continue_as_new(
406        &self,
407        state: serde_json::Value,
408    ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;
409
410    /// Sleep for a duration (durable — survives replay).
411    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + WasmCompatSend;
412
413    /// Current time (deterministic during replay).
414    fn now(&self) -> chrono::DateTime<chrono::Utc>;
415}
416
417// --- Permission Policy ---
418
/// Decision from a permission check on a single tool call.
///
/// Besides `Clone`, the decision implements `PartialEq`/`Eq`, so callers can
/// compare it directly (for example `decision == PermissionDecision::Allow`)
/// instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call with a reason.
    Deny(String),
    /// Ask the user for confirmation, with an accompanying message.
    Ask(String),
}
429
430/// Policy for checking tool call permissions.
431///
432/// # Example
433///
434/// ```no_run
435/// use neuron_types::*;
436///
437/// struct AllowAll;
438/// impl PermissionPolicy for AllowAll {
439///     fn check(&self, _tool_name: &str, _input: &serde_json::Value) -> PermissionDecision {
440///         PermissionDecision::Allow
441///     }
442/// }
443/// ```
444pub trait PermissionPolicy: WasmCompatSend + WasmCompatSync {
445    /// Check whether a tool call is permitted.
446    fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
447}