neuron_types/traits.rs
//! Core traits: Provider, EmbeddingProvider, Tool, ToolDyn, ContextStrategy,
//! ObservabilityHook, DurableContext, PermissionPolicy.
2
3use std::future::Future;
4use std::time::Duration;
5
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8
9use crate::error::{ContextError, DurableError, EmbeddingError, HookError, ProviderError, ToolError};
10use crate::stream::StreamHandle;
11use crate::types::{
12 CompletionRequest, CompletionResponse, ContentItem, EmbeddingRequest, EmbeddingResponse,
13 Message, ToolContext, ToolDefinition, ToolOutput,
14};
15use crate::wasm::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
16
17/// LLM provider trait. Implement this for each provider (Anthropic, OpenAI, Ollama, etc.).
18///
19/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
20/// Not object-safe by design; use generics `<P: Provider>` to compose.
21///
22/// # Example
23///
24/// ```no_run
25/// use std::future::Future;
26/// use neuron_types::*;
27///
28/// struct MyProvider;
29///
30/// impl Provider for MyProvider {
31/// fn complete(&self, request: CompletionRequest)
32/// -> impl Future<Output = Result<CompletionResponse, ProviderError>> + Send
33/// {
34/// async { todo!() }
35/// }
36///
37/// fn complete_stream(&self, request: CompletionRequest)
38/// -> impl Future<Output = Result<StreamHandle, ProviderError>> + Send
39/// {
40/// async { todo!() }
41/// }
42/// }
43/// ```
44pub trait Provider: WasmCompatSend + WasmCompatSync {
45 /// Send a completion request and get a full response.
46 fn complete(
47 &self,
48 request: CompletionRequest,
49 ) -> impl Future<Output = Result<CompletionResponse, ProviderError>> + WasmCompatSend;
50
51 /// Send a completion request and get a stream of events.
52 fn complete_stream(
53 &self,
54 request: CompletionRequest,
55 ) -> impl Future<Output = Result<StreamHandle, ProviderError>> + WasmCompatSend;
56}
57
58/// Embedding provider trait. Implement this for providers that support text embeddings.
59///
60/// Kept separate from [`Provider`] because not all embedding models support chat completion
61/// and not all chat providers support embeddings. Implement both traits on a struct when a
62/// provider supports both capabilities.
63///
64/// Uses RPITIT (return position impl trait in trait) — Rust 2024 native async.
65/// Not object-safe by design; use generics `<E: EmbeddingProvider>` to compose.
66///
67/// # Example
68///
69/// ```no_run
70/// use std::future::Future;
71/// use neuron_types::*;
72///
73/// struct MyEmbeddingProvider;
74///
75/// impl EmbeddingProvider for MyEmbeddingProvider {
76/// fn embed(&self, request: EmbeddingRequest)
77/// -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + Send
78/// {
79/// async { todo!() }
80/// }
81/// }
82/// ```
83pub trait EmbeddingProvider: WasmCompatSend + WasmCompatSync {
84 /// Send an embedding request and get vectors back.
85 ///
86 /// Multiple input strings are batched into a single request. The returned
87 /// `embeddings` vec is in the same order as `request.input`.
88 fn embed(
89 &self,
90 request: EmbeddingRequest,
91 ) -> impl Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend;
92}
93
94/// Strongly-typed tool trait. Implement this for your tools.
95///
96/// The blanket impl of [`ToolDyn`] handles JSON deserialization/serialization
97/// so you work with concrete Rust types.
98///
99/// # Example
100///
101/// ```no_run
102/// use neuron_types::*;
103/// use serde::Deserialize;
104///
105/// #[derive(Debug, Deserialize, schemars::JsonSchema)]
106/// struct MyArgs { query: String }
107///
108/// struct MyTool;
109/// impl Tool for MyTool {
110/// const NAME: &'static str = "my_tool";
111/// type Args = MyArgs;
112/// type Output = String;
113/// type Error = std::io::Error;
114///
115/// fn definition(&self) -> ToolDefinition { todo!() }
116/// fn call(&self, args: MyArgs, ctx: &ToolContext)
117/// -> impl Future<Output = Result<String, std::io::Error>> + Send
118/// { async { Ok(args.query) } }
119/// }
120/// ```
121pub trait Tool: WasmCompatSend + WasmCompatSync {
122 /// The unique name of this tool.
123 const NAME: &'static str;
124 /// The deserialized input type.
125 type Args: DeserializeOwned + schemars::JsonSchema + WasmCompatSend;
126 /// The serializable output type.
127 type Output: Serialize;
128 /// The tool-specific error type.
129 type Error: std::error::Error + WasmCompatSend + 'static;
130
131 /// Returns the tool definition (name, description, schema).
132 fn definition(&self) -> ToolDefinition;
133
134 /// Execute the tool with typed arguments.
135 fn call(
136 &self,
137 args: Self::Args,
138 ctx: &ToolContext,
139 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
140}
141
142/// Type-erased tool for dynamic dispatch. Blanket-implemented for all [`Tool`] impls.
143///
144/// This enables heterogeneous tool collections (`HashMap<String, Arc<dyn ToolDyn>>`)
145/// while preserving type safety at the implementation level.
146pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
147 /// The tool's unique name.
148 fn name(&self) -> &str;
149 /// The tool definition (name, description, input schema).
150 fn definition(&self) -> ToolDefinition;
151 /// Execute the tool with a JSON value input, returning a generic output.
152 fn call_dyn<'a>(
153 &'a self,
154 input: serde_json::Value,
155 ctx: &'a ToolContext,
156 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>>;
157}
158
159/// Blanket implementation: any `Tool` automatically becomes a `ToolDyn`.
160///
161/// Handles:
162/// - Deserializing `serde_json::Value` into `T::Args`
163/// - Calling `T::call(args, ctx)`
164/// - Serializing `T::Output` into `ToolOutput`
165/// - Mapping `T::Error` into `ToolError::ExecutionFailed`
166impl<T: Tool> ToolDyn for T {
167 fn name(&self) -> &str {
168 T::NAME
169 }
170
171 fn definition(&self) -> ToolDefinition {
172 Tool::definition(self)
173 }
174
175 fn call_dyn<'a>(
176 &'a self,
177 input: serde_json::Value,
178 ctx: &'a ToolContext,
179 ) -> WasmBoxedFuture<'a, Result<ToolOutput, ToolError>> {
180 Box::pin(async move {
181 let args: T::Args = serde_json::from_value(input)
182 .map_err(|e| ToolError::InvalidInput(e.to_string()))?;
183
184 let output = self.call(args, ctx).await.map_err(|e| {
185 ToolError::ExecutionFailed(e.to_string().into())
186 })?;
187
188 let structured = serde_json::to_value(&output)
189 .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?;
190
191 let text = match &structured {
192 serde_json::Value::String(s) => s.clone(),
193 other => other.to_string(),
194 };
195
196 Ok(ToolOutput {
197 content: vec![ContentItem::Text(text)],
198 structured_content: Some(structured),
199 is_error: false,
200 })
201 })
202 }
203}
204
205// --- Context Strategy ---
206
207/// Strategy for compacting conversation context when it exceeds token limits.
208///
209/// Implementations decide when to compact and how to reduce the message list.
210///
211/// # Example
212///
213/// ```no_run
214/// use std::future::Future;
215/// use neuron_types::*;
216///
217/// struct KeepLastN { n: usize }
218/// impl ContextStrategy for KeepLastN {
219/// fn should_compact(&self, _messages: &[Message], token_count: usize) -> bool {
220/// token_count > 100_000
221/// }
222/// fn compact(&self, messages: Vec<Message>)
223/// -> impl Future<Output = Result<Vec<Message>, ContextError>> + Send
224/// {
225/// async move { Ok(messages.into_iter().rev().take(self.n).collect()) }
226/// }
227/// fn token_estimate(&self, messages: &[Message]) -> usize { messages.len() * 100 }
228/// }
229/// ```
230pub trait ContextStrategy: WasmCompatSend + WasmCompatSync {
231 /// Whether compaction should be triggered given the current messages and token count.
232 fn should_compact(&self, messages: &[Message], token_count: usize) -> bool;
233
234 /// Compact the message list to reduce token usage.
235 fn compact(
236 &self,
237 messages: Vec<Message>,
238 ) -> impl Future<Output = Result<Vec<Message>, ContextError>> + WasmCompatSend;
239
240 /// Estimate the token count for a list of messages.
241 fn token_estimate(&self, messages: &[Message]) -> usize;
242}
243
244// --- Observability Hooks ---
245
246/// Events fired during the agentic loop for observability.
247#[derive(Debug)]
248pub enum HookEvent<'a> {
249 /// Start of a loop iteration.
250 LoopIteration {
251 /// The current turn number (0-indexed).
252 turn: usize,
253 },
254 /// Before calling the LLM provider.
255 PreLlmCall {
256 /// The request about to be sent.
257 request: &'a CompletionRequest,
258 },
259 /// After receiving the LLM response.
260 PostLlmCall {
261 /// The response received.
262 response: &'a CompletionResponse,
263 },
264 /// Before executing a tool.
265 PreToolExecution {
266 /// Name of the tool.
267 tool_name: &'a str,
268 /// Input arguments.
269 input: &'a serde_json::Value,
270 },
271 /// After executing a tool.
272 PostToolExecution {
273 /// Name of the tool.
274 tool_name: &'a str,
275 /// The tool's output.
276 output: &'a ToolOutput,
277 },
278 /// Context was compacted.
279 ContextCompaction {
280 /// Token count before compaction.
281 old_tokens: usize,
282 /// Token count after compaction.
283 new_tokens: usize,
284 },
285 /// A session started.
286 SessionStart {
287 /// The session identifier.
288 session_id: &'a str,
289 },
290 /// A session ended.
291 SessionEnd {
292 /// The session identifier.
293 session_id: &'a str,
294 },
295}
296
/// Action to take after processing a hook event.
///
/// Cloneable and comparable so callers can store an action and assert on it
/// (for example `assert_eq!(action, HookAction::Continue)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Continue normal execution.
    Continue,
    /// Skip the current operation and return a rejection message.
    Skip {
        /// Reason for skipping.
        reason: String,
    },
    /// Terminate the loop immediately.
    Terminate {
        /// Reason for termination.
        reason: String,
    },
}
313
314/// Hook for observability (logging, metrics, telemetry).
315///
316/// Does NOT control execution flow beyond Continue/Skip/Terminate.
317/// For durable execution wrapping, use [`DurableContext`] instead.
318///
319/// # Example
320///
321/// ```no_run
322/// use std::future::Future;
323/// use neuron_types::*;
324///
325/// struct LogHook;
326/// impl ObservabilityHook for LogHook {
327/// fn on_event(&self, event: HookEvent<'_>)
328/// -> impl Future<Output = Result<HookAction, HookError>> + Send
329/// {
330/// async move { println!("{event:?}"); Ok(HookAction::Continue) }
331/// }
332/// }
333/// ```
334pub trait ObservabilityHook: WasmCompatSend + WasmCompatSync {
335 /// Called for each event in the agentic loop.
336 fn on_event(
337 &self,
338 event: HookEvent<'_>,
339 ) -> impl Future<Output = Result<HookAction, HookError>> + WasmCompatSend;
340}
341
342// --- Durable Context ---
343
/// Options for executing an activity in a durable context.
///
/// Comparable with `==` so configurations can be checked in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ActivityOptions {
    /// Maximum time for the activity to complete.
    pub start_to_close_timeout: Duration,
    /// Heartbeat interval for long-running activities.
    pub heartbeat_timeout: Option<Duration>,
    /// Retry policy for failed activities.
    pub retry_policy: Option<RetryPolicy>,
}

/// Retry policy for durable activities.
///
/// Implements `PartialEq` but not `Eq`, because `backoff_coefficient` is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryPolicy {
    /// Initial delay before first retry.
    pub initial_interval: Duration,
    /// Multiplier for exponential backoff.
    pub backoff_coefficient: f64,
    /// Maximum number of retry attempts.
    pub maximum_attempts: u32,
    /// Maximum delay between retries.
    pub maximum_interval: Duration,
    /// Error types that should not be retried.
    pub non_retryable_errors: Vec<String>,
}
369
370/// Wraps side effects for durable execution engines (Temporal, Restate, Inngest).
371///
372/// When present, the agentic loop calls through this instead of directly
373/// calling provider/tools, enabling journaling, replay, and crash recovery.
374pub trait DurableContext: WasmCompatSend + WasmCompatSync {
375 /// Execute an LLM call as a durable activity.
376 fn execute_llm_call(
377 &self,
378 request: CompletionRequest,
379 options: ActivityOptions,
380 ) -> impl Future<Output = Result<CompletionResponse, DurableError>> + WasmCompatSend;
381
382 /// Execute a tool call as a durable activity.
383 fn execute_tool(
384 &self,
385 tool_name: &str,
386 input: serde_json::Value,
387 ctx: &ToolContext,
388 options: ActivityOptions,
389 ) -> impl Future<Output = Result<ToolOutput, DurableError>> + WasmCompatSend;
390
391 /// Wait for an external signal with a timeout.
392 fn wait_for_signal<T: DeserializeOwned + WasmCompatSend>(
393 &self,
394 signal_name: &str,
395 timeout: Duration,
396 ) -> impl Future<Output = Result<Option<T>, DurableError>> + WasmCompatSend;
397
398 /// Whether the workflow should continue-as-new to avoid history bloat.
399 fn should_continue_as_new(&self) -> bool;
400
401 /// Continue the workflow as a new execution with the given state.
402 fn continue_as_new(
403 &self,
404 state: serde_json::Value,
405 ) -> impl Future<Output = Result<(), DurableError>> + WasmCompatSend;
406
407 /// Sleep for a duration (durable — survives replay).
408 fn sleep(
409 &self,
410 duration: Duration,
411 ) -> impl Future<Output = ()> + WasmCompatSend;
412
413 /// Current time (deterministic during replay).
414 fn now(&self) -> chrono::DateTime<chrono::Utc>;
415}
416
417// --- Permission Policy ---
418
/// Decision from a permission check.
///
/// Comparable with `==` so policies can be asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionDecision {
    /// Allow the tool call.
    Allow,
    /// Deny the tool call with a reason.
    Deny(String),
    /// Ask the user for confirmation.
    Ask(String),
}
429
430/// Policy for checking tool call permissions.
431///
432/// # Example
433///
434/// ```no_run
435/// use neuron_types::*;
436///
437/// struct AllowAll;
438/// impl PermissionPolicy for AllowAll {
439/// fn check(&self, _tool_name: &str, _input: &serde_json::Value) -> PermissionDecision {
440/// PermissionDecision::Allow
441/// }
442/// }
443/// ```
444pub trait PermissionPolicy: WasmCompatSend + WasmCompatSync {
445 /// Check whether a tool call is permitted.
446 fn check(&self, tool_name: &str, input: &serde_json::Value) -> PermissionDecision;
447}