use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::{debug, info, trace, warn};
use crate::{
ContentPart, Message, MessageContent, Role,
hook::{
Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd,
OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest,
PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult,
},
llm_client::{
ClientError, ConfigWarning, LlmClient, Request, RequestConfig,
ToolDefinition as LlmToolDefinition,
},
state::{CacheLocked, Mutable, WorkerState},
subscriber::{
ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
},
timeline::{TextBlockCollector, Timeline, ToolCallCollector},
tool::{Tool, ToolDefinition, ToolError, ToolMeta},
};
/// Errors returned by [`Worker`] runs and configuration validation.
#[derive(Debug, thiserror::Error)]
pub enum WorkerError {
    /// An error reported by the LLM client (converted via `From`).
    #[error("Client error: {0}")]
    Client(#[from] ClientError),
    /// A tool error (converted via `From`).
    #[error("Tool error: {0}")]
    Tool(#[from] ToolError),
    /// A hook returned an error (converted via `From`).
    #[error("Hook error: {0}")]
    Hook(#[from] HookError),
    /// A hook cancelled or aborted the run; carries the hook-supplied reason.
    #[error("Aborted: {0}")]
    Aborted(String),
    /// A cancellation request was observed on the worker's cancel channel.
    #[error("Cancelled")]
    Cancelled,
    /// Config validation produced warnings; the message joins all of them with ", ".
    #[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
    ConfigWarnings(Vec<ConfigWarning>),
}
/// Errors raised while registering tools on a [`Worker`].
#[derive(Debug, thiserror::Error)]
pub enum ToolRegistryError {
    /// A tool with the given name is already registered; the existing tool is kept.
    #[error("Tool with name '{0}' already registered")]
    DuplicateName(String),
}
/// Reserved worker configuration.
///
/// It currently carries no settings. The private unit field means it can only be
/// obtained through `Default` (or `Clone`) outside this module.
#[derive(Debug, Clone, Default)]
pub struct WorkerConfig {
    _private: (),
}
/// Successful outcome of a `run`/`resume` call.
#[derive(Debug)]
pub enum WorkerResult {
    /// The model produced a turn without tool calls and the on-turn-end hooks chose to finish.
    Finished,
    /// A hook paused the run; the history is kept so `resume` can continue from it.
    Paused,
}
/// Internal outcome of executing one batch of tool calls.
enum ToolExecutionResult {
    /// Results for every call that was executed (calls skipped by a hook are absent).
    Completed(Vec<ToolResult>),
    /// A pre-tool-call hook paused execution before any tool ran.
    Paused,
}
/// Type-erased receiver of turn start/end notifications, so subscribers of different
/// types can be stored together in one `Vec<Box<dyn TurnNotifier>>`.
trait TurnNotifier: Send + Sync {
    /// Called before a turn's LLM request is prepared; `turn` is the worker's turn counter.
    fn on_turn_start(&self, turn: usize);
    /// Called when that turn ends.
    fn on_turn_end(&self, turn: usize);
}
/// Forwards turn notifications to a [`WorkerSubscriber`].
///
/// The same `Arc<Mutex<_>>` is shared with the timeline adapters created in `subscribe`.
struct SubscriberTurnNotifier<S: WorkerSubscriber + 'static> {
    subscriber: Arc<Mutex<S>>,
}
impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
    /// Locks the shared subscriber and reports the start of `turn`.
    ///
    /// If the mutex is poisoned the notification is dropped silently.
    fn on_turn_start(&self, turn: usize) {
        let Ok(mut guard) = self.subscriber.lock() else {
            return;
        };
        guard.on_turn_start(turn);
    }

    /// Locks the shared subscriber and reports the end of `turn`.
    ///
    /// If the mutex is poisoned the notification is dropped silently.
    fn on_turn_end(&self, turn: usize) {
        let Ok(mut guard) = self.subscriber.lock() else {
            return;
        };
        guard.on_turn_end(turn);
    }
}
/// Drives a conversation with an LLM.
///
/// A worker streams each request's response through a [`Timeline`] and executes requested
/// tools. It invokes the registered hooks and records everything in its message history.
///
/// `S` is a type-state marker. [`Mutable`] (the default) exposes the history-editing
/// methods. [`CacheLocked`] (obtained via `lock`) remembers how long the history was
/// when it was locked.
pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
    /// LLM client used to stream each turn's request.
    client: C,
    /// Receives the streamed events; collectors and subscriber adapters are registered on it.
    timeline: Timeline,
    /// Collects text blocks from the timeline; drained after each completed stream.
    text_block_collector: TextBlockCollector,
    /// Collects tool-use blocks from the timeline; drained after each completed stream.
    tool_call_collector: ToolCallCollector,
    /// Registered tools keyed by tool name.
    tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
    /// Hooks invoked at the various stages of a run.
    hooks: HookRegistry,
    /// Optional system prompt added to every request.
    system_prompt: Option<String>,
    /// Conversation history (user, assistant and tool-result messages).
    history: Vec<Message>,
    /// History length captured by `lock()`; 0 for a mutable worker.
    locked_prefix_len: usize,
    /// Number of LLM turns whose stream completed.
    turn_count: usize,
    /// One notifier per subscriber, told about turn start/end.
    turn_notifiers: Vec<Box<dyn TurnNotifier>>,
    /// Generation settings cloned into every request.
    request_config: RequestConfig,
    /// Whether the most recent `run`/`resume` ended with an error or a pause.
    last_run_interrupted: bool,
    /// Sending half of the cancellation channel (created with capacity 1 in `new`).
    cancel_tx: mpsc::Sender<()>,
    /// Receiving half, polled by the run loop to observe cancellation requests.
    cancel_rx: mpsc::Receiver<()>,
    /// Type-state marker; carries no data.
    _state: PhantomData<S>,
}
impl<C: LlmClient, S: WorkerState> Worker<C, S> {
fn reset_interruption_state(&mut self) {
self.last_run_interrupted = false;
}
pub async fn run(
&mut self,
user_input: impl Into<String>,
) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
let mut user_message = Message::user(user_input);
let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
let result = match result {
Ok(value) => value,
Err(err) => return self.finalize_interruption(Err(err)).await,
};
match result {
OnPromptSubmitResult::Cancel(reason) => {
self.last_run_interrupted = true;
return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
}
OnPromptSubmitResult::Continue => {}
}
self.history.push(user_message);
let result = self.run_turn_loop().await;
self.finalize_interruption(result).await
}
fn drain_cancel_queue(&mut self) {
use tokio::sync::mpsc::error::TryRecvError;
loop {
match self.cancel_rx.try_recv() {
Ok(()) => continue,
Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
}
}
}
fn try_cancelled(&mut self) -> bool {
use tokio::sync::mpsc::error::TryRecvError;
match self.cancel_rx.try_recv() {
Ok(()) => true,
Err(TryRecvError::Empty) => false,
Err(TryRecvError::Disconnected) => true,
}
}
pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
let subscriber = Arc::new(Mutex::new(subscriber));
self.timeline
.on_text_block(TextBlockSubscriberAdapter::new(subscriber.clone()));
self.timeline
.on_tool_use_block(ToolUseBlockSubscriberAdapter::new(subscriber.clone()));
self.timeline
.on_usage(UsageSubscriberAdapter::new(subscriber.clone()));
self.timeline
.on_status(StatusSubscriberAdapter::new(subscriber.clone()));
self.timeline
.on_error(ErrorSubscriberAdapter::new(subscriber.clone()));
self.turn_notifiers
.push(Box::new(SubscriberTurnNotifier { subscriber }));
}
pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> {
let (meta, instance) = factory();
if self.tools.contains_key(&meta.name) {
return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
}
self.tools.insert(meta.name.clone(), (meta, instance));
Ok(())
}
pub fn register_tools(
&mut self,
factories: impl IntoIterator<Item = ToolDefinition>,
) -> Result<(), ToolRegistryError> {
for factory in factories {
self.register_tool(factory)?;
}
Ok(())
}
pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
self.hooks.on_prompt_submit.push(Box::new(hook));
}
pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
self.hooks.pre_llm_request.push(Box::new(hook));
}
pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
self.hooks.pre_tool_call.push(Box::new(hook));
}
pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
self.hooks.post_tool_call.push(Box::new(hook));
}
pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
self.hooks.on_turn_end.push(Box::new(hook));
}
pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
self.hooks.on_abort.push(Box::new(hook));
}
pub fn timeline_mut(&mut self) -> &mut Timeline {
&mut self.timeline
}
pub fn history(&self) -> &[Message] {
&self.history
}
pub fn get_system_prompt(&self) -> Option<&str> {
self.system_prompt.as_deref()
}
pub fn turn_count(&self) -> usize {
self.turn_count
}
pub fn request_config(&self) -> &RequestConfig {
&self.request_config
}
pub fn set_max_tokens(&mut self, max_tokens: u32) {
self.request_config.max_tokens = Some(max_tokens);
}
pub fn set_temperature(&mut self, temperature: f32) {
self.request_config.temperature = Some(temperature);
}
pub fn set_top_p(&mut self, top_p: f32) {
self.request_config.top_p = Some(top_p);
}
pub fn set_top_k(&mut self, top_k: u32) {
self.request_config.top_k = Some(top_k);
}
pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
self.request_config.stop_sequences.push(sequence.into());
}
pub fn clear_stop_sequences(&mut self) {
self.request_config.stop_sequences.clear();
}
pub fn cancel_sender(&self) -> mpsc::Sender<()> {
self.cancel_tx.clone()
}
pub fn set_request_config(&mut self, config: RequestConfig) {
self.request_config = config;
}
pub fn cancel(&self) {
let _ = self.cancel_tx.try_send(());
}
pub fn is_cancelled(&mut self) -> bool {
self.try_cancelled()
}
pub fn last_run_interrupted(&self) -> bool {
self.last_run_interrupted
}
fn build_tool_definitions(&self) -> Vec<LlmToolDefinition> {
self.tools
.values()
.map(|(meta, _)| {
LlmToolDefinition::new(&meta.name)
.description(&meta.description)
.input_schema(meta.input_schema.clone())
})
.collect()
}
fn build_assistant_message(
&self,
text_blocks: &[String],
tool_calls: &[ToolCall],
) -> Option<Message> {
if text_blocks.is_empty() && tool_calls.is_empty() {
return None;
}
if tool_calls.is_empty() {
let text = text_blocks.join("");
return Some(Message::assistant(text));
}
let mut parts = Vec::new();
for text in text_blocks {
if !text.is_empty() {
parts.push(ContentPart::Text { text: text.clone() });
}
}
for call in tool_calls {
parts.push(ContentPart::ToolUse {
id: call.id.clone(),
name: call.name.clone(),
input: call.input.clone(),
});
}
Some(Message {
role: Role::Assistant,
content: MessageContent::Parts(parts),
})
}
fn build_request(
&self,
tool_definitions: &[LlmToolDefinition],
context: &[Message],
) -> Request {
let mut request = Request::new();
if let Some(ref system) = self.system_prompt {
request = request.system(system);
}
for msg in context {
request = request.message(crate::llm_client::Message {
role: match msg.role {
Role::User => crate::llm_client::Role::User,
Role::Assistant => crate::llm_client::Role::Assistant,
},
content: match &msg.content {
MessageContent::Text(t) => crate::llm_client::MessageContent::Text(t.clone()),
MessageContent::ToolResult {
tool_use_id,
content,
} => crate::llm_client::MessageContent::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
},
MessageContent::Parts(parts) => crate::llm_client::MessageContent::Parts(
parts
.iter()
.map(|p| match p {
ContentPart::Text { text } => {
crate::llm_client::ContentPart::Text { text: text.clone() }
}
ContentPart::ToolUse { id, name, input } => {
crate::llm_client::ContentPart::ToolUse {
id: id.clone(),
name: name.clone(),
input: input.clone(),
}
}
ContentPart::ToolResult {
tool_use_id,
content,
} => crate::llm_client::ContentPart::ToolResult {
tool_use_id: tool_use_id.clone(),
content: content.clone(),
},
})
.collect(),
),
},
});
}
for tool_def in tool_definitions {
request = request.tool(tool_def.clone());
}
request = request.config(self.request_config.clone());
request
}
async fn run_on_prompt_submit_hooks(
&self,
message: &mut Message,
) -> Result<OnPromptSubmitResult, WorkerError> {
for hook in &self.hooks.on_prompt_submit {
let result = hook.call(message).await?;
match result {
OnPromptSubmitResult::Continue => continue,
OnPromptSubmitResult::Cancel(reason) => {
return Ok(OnPromptSubmitResult::Cancel(reason));
}
}
}
Ok(OnPromptSubmitResult::Continue)
}
async fn run_pre_llm_request_hooks(
&self,
) -> Result<(PreLlmRequestResult, Vec<Message>), WorkerError> {
let mut temp_context = self.history.clone();
for hook in &self.hooks.pre_llm_request {
let result = hook.call(&mut temp_context).await?;
match result {
PreLlmRequestResult::Continue => continue,
PreLlmRequestResult::Cancel(reason) => {
return Ok((PreLlmRequestResult::Cancel(reason), temp_context));
}
}
}
Ok((PreLlmRequestResult::Continue, temp_context))
}
async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
let mut temp_messages = self.history.clone();
for hook in &self.hooks.on_turn_end {
let result = hook.call(&mut temp_messages).await?;
match result {
OnTurnEndResult::Finish => continue,
OnTurnEndResult::ContinueWithMessages(msgs) => {
return Ok(OnTurnEndResult::ContinueWithMessages(msgs));
}
OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused),
}
}
Ok(OnTurnEndResult::Finish)
}
async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
let mut reason = reason.to_string();
for hook in &self.hooks.on_abort {
hook.call(&mut reason).await?;
}
Ok(())
}
async fn finalize_interruption<T>(
&mut self,
result: Result<T, WorkerError>,
) -> Result<T, WorkerError> {
match result {
Ok(value) => Ok(value),
Err(err) => {
self.last_run_interrupted = true;
let reason = match &err {
WorkerError::Aborted(reason) => reason.clone(),
WorkerError::Cancelled => "Cancelled".to_string(),
_ => err.to_string(),
};
if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
self.last_run_interrupted = true;
return Err(hook_err);
}
Err(err)
}
}
}
fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
let last_msg = self.history.last()?;
if last_msg.role != Role::Assistant {
return None;
}
let mut calls = Vec::new();
if let MessageContent::Parts(parts) = &last_msg.content {
for part in parts {
if let ContentPart::ToolUse { id, name, input } = part {
calls.push(ToolCall {
id: id.clone(),
name: name.clone(),
input: input.clone(),
});
}
}
}
if calls.is_empty() { None } else { Some(calls) }
}
async fn execute_tools(
&mut self,
tool_calls: Vec<ToolCall>,
) -> Result<ToolExecutionResult, WorkerError> {
use futures::future::join_all;
let mut call_info_map = HashMap::new();
let mut approved_calls = Vec::new();
for mut tool_call in tool_calls {
if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
let mut context = ToolCallContext {
call: tool_call.clone(),
meta: meta.clone(),
tool: tool.clone(),
};
let mut skip = false;
for hook in &self.hooks.pre_tool_call {
let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result {
PreToolCallResult::Continue => {}
PreToolCallResult::Skip => {
skip = true;
break;
}
PreToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
PreToolCallResult::Pause => {
self.last_run_interrupted = true;
return Ok(ToolExecutionResult::Paused);
}
}
}
tool_call = context.call;
if !skip {
call_info_map.insert(
tool_call.id.clone(),
(tool_call.clone(), meta.clone(), tool.clone()),
);
approved_calls.push(tool_call);
}
} else {
approved_calls.push(tool_call);
}
}
let futures: Vec<_> = approved_calls
.into_iter()
.map(|tool_call| {
let tools = &self.tools;
async move {
if let Some((_, tool)) = tools.get(&tool_call.name) {
let input_json =
serde_json::to_string(&tool_call.input).unwrap_or_default();
match tool.execute(&input_json).await {
Ok(content) => ToolResult::success(&tool_call.id, content),
Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
}
} else {
ToolResult::error(
&tool_call.id,
format!("Tool '{}' not found", tool_call.name),
)
}
}
})
.collect();
let mut results = tokio::select! {
results = join_all(futures) => results,
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Tool execution cancelled");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
for tool_result in &mut results {
if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
let mut context = PostToolCallContext {
call: tool_call.clone(),
result: tool_result.clone(),
meta: meta.clone(),
tool: tool.clone(),
};
for hook in &self.hooks.post_tool_call {
let result = hook
.call(&mut context)
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match result {
PostToolCallResult::Continue => {}
PostToolCallResult::Abort(reason) => {
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
}
}
*tool_result = context.result;
}
}
Ok(ToolExecutionResult::Completed(results))
}
async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
self.drain_cancel_queue();
let tool_definitions = self.build_tool_definitions();
info!(
message_count = self.history.len(),
tool_count = tool_definitions.len(),
"Starting worker run"
);
if let Some(tool_calls) = self.get_pending_tool_calls() {
info!("Resuming pending tool calls");
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
}
}
loop {
if self.try_cancelled() {
info!("Execution cancelled");
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
let current_turn = self.turn_count;
debug!(turn = current_turn, "Turn start");
for notifier in &self.turn_notifiers {
notifier.on_turn_start(current_turn);
}
let (control, request_context) = self
.run_pre_llm_request_hooks()
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match control {
PreLlmRequestResult::Cancel(reason) => {
info!(reason = %reason, "Aborted by hook");
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
self.last_run_interrupted = true;
return Err(WorkerError::Aborted(reason));
}
PreLlmRequestResult::Continue => {}
}
let request = self.build_request(&tool_definitions, &request_context);
debug!(
message_count = request.messages.len(),
tool_count = request.tools.len(),
has_system = request.system_prompt.is_some(),
"Sending request to LLM"
);
debug!("Starting stream...");
let mut event_count = 0;
let mut stream = tokio::select! {
stream_result = self.client.stream(request) => stream_result
.inspect_err(|_| self.last_run_interrupted = true)?,
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Cancelled before stream started");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
};
loop {
tokio::select! {
event_result = stream.next() => {
match event_result {
Some(result) => {
match &result {
Ok(event) => {
trace!(event = ?event, "Received event");
event_count += 1;
}
Err(e) => {
warn!(error = %e, "Stream error");
}
}
let event = result
.inspect_err(|_| self.last_run_interrupted = true)?;
let timeline_event: crate::timeline::event::Event = event.into();
self.timeline.dispatch(&timeline_event);
}
None => break, }
}
cancel = self.cancel_rx.recv() => {
if cancel.is_some() {
info!("Stream cancelled");
}
self.timeline.abort_current_block();
self.last_run_interrupted = true;
return Err(WorkerError::Cancelled);
}
}
}
debug!(event_count = event_count, "Stream completed");
for notifier in &self.turn_notifiers {
notifier.on_turn_end(current_turn);
}
self.turn_count += 1;
let text_blocks = self.text_block_collector.take_collected();
let tool_calls = self.tool_call_collector.take_collected();
let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
if let Some(msg) = assistant_message {
self.history.push(msg);
}
if tool_calls.is_empty() {
let turn_result = self
.run_on_turn_end_hooks()
.await
.inspect_err(|_| self.last_run_interrupted = true)?;
match turn_result {
OnTurnEndResult::Finish => {
self.last_run_interrupted = false;
return Ok(WorkerResult::Finished);
}
OnTurnEndResult::ContinueWithMessages(additional) => {
self.history.extend(additional);
continue;
}
OnTurnEndResult::Paused => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
}
}
match self.execute_tools(tool_calls).await {
Ok(ToolExecutionResult::Paused) => {
self.last_run_interrupted = true;
return Ok(WorkerResult::Paused);
}
Ok(ToolExecutionResult::Completed(results)) => {
for result in results {
self.history
.push(Message::tool_result(&result.tool_use_id, &result.content));
}
}
Err(err) => {
self.last_run_interrupted = true;
return Err(err);
}
}
}
}
pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
self.reset_interruption_state();
let result = self.run_turn_loop().await;
self.finalize_interruption(result).await
}
}
impl<C: LlmClient> Worker<C, Mutable> {
    /// Creates a mutable worker around `client`.
    ///
    /// The worker starts with:
    /// - an empty history;
    /// - no tools, hooks or system prompt;
    /// - a default [`RequestConfig`];
    /// - a cancellation channel that holds at most one pending request.
    ///
    /// The text-block and tool-call collectors are registered on the timeline here.
    pub fn new(client: C) -> Self {
        let (cancel_tx, cancel_rx) = mpsc::channel(1);
        let text_block_collector = TextBlockCollector::new();
        let tool_call_collector = ToolCallCollector::new();

        let mut timeline = Timeline::new();
        timeline.on_text_block(text_block_collector.clone());
        timeline.on_tool_use_block(tool_call_collector.clone());

        Self {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tools: HashMap::new(),
            hooks: HookRegistry::new(),
            system_prompt: None,
            history: Vec::new(),
            locked_prefix_len: 0,
            turn_count: 0,
            turn_notifiers: Vec::new(),
            request_config: RequestConfig::default(),
            last_run_interrupted: false,
            cancel_tx,
            cancel_rx,
            _state: PhantomData,
        }
    }

    /// Builder form of [`Self::set_system_prompt`].
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.set_system_prompt(prompt);
        self
    }

    /// Sets (or replaces) the system prompt.
    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
        self.system_prompt = Some(prompt.into());
    }

    /// Builder form of `set_max_tokens`.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.set_max_tokens(max_tokens);
        self
    }

    /// Builder form of `set_temperature`.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.set_temperature(temperature);
        self
    }

    /// Builder form of `set_top_p`.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.set_top_p(top_p);
        self
    }

    /// Builder form of `set_top_k`.
    pub fn top_k(mut self, top_k: u32) -> Self {
        self.set_top_k(top_k);
        self
    }

    /// Builder form of `add_stop_sequence`; appends one stop sequence.
    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.add_stop_sequence(sequence);
        self
    }

    /// Builder form of `set_request_config`; replaces the whole request config.
    pub fn with_config(mut self, config: RequestConfig) -> Self {
        self.set_request_config(config);
        self
    }

    /// Asks the client to validate the current request config.
    ///
    /// # Errors
    ///
    /// Returns [`WorkerError::ConfigWarnings`] carrying every warning if there are any.
    pub fn validate(self) -> Result<Self, WorkerError> {
        let warnings = self.client.validate_config(&self.request_config);
        if !warnings.is_empty() {
            return Err(WorkerError::ConfigWarnings(warnings));
        }
        Ok(self)
    }

    /// Mutable access to the raw history vector.
    pub fn history_mut(&mut self) -> &mut Vec<Message> {
        &mut self.history
    }

    /// Replaces the whole history.
    pub fn set_history(&mut self, messages: Vec<Message>) {
        self.history = messages;
    }

    /// Builder form of [`Self::push_message`].
    pub fn with_message(mut self, message: Message) -> Self {
        self.push_message(message);
        self
    }

    /// Appends one message to the history.
    pub fn push_message(&mut self, message: Message) {
        self.history.push(message);
    }

    /// Builder form of [`Self::extend_history`].
    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
        self.extend_history(messages);
        self
    }

    /// Appends all `messages` to the history, in order.
    pub fn extend_history(&mut self, messages: impl IntoIterator<Item = Message>) {
        self.history.extend(messages);
    }

    /// Removes every message from the history.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }

    /// Placeholder: [`WorkerConfig`] has no settings yet, so the argument is ignored.
    // NOTE(review): presumably `dead_code` is allowed because this module is not always
    // publicly reachable; confirm the attribute is still needed.
    #[allow(dead_code)]
    pub fn config(self, _config: WorkerConfig) -> Self {
        self
    }

    /// Converts into a [`CacheLocked`] worker.
    ///
    /// The current history length is recorded as the locked prefix; all other state moves
    /// over unchanged.
    pub fn lock(self) -> Worker<C, CacheLocked> {
        let Worker {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tools,
            hooks,
            system_prompt,
            history,
            locked_prefix_len: _,
            turn_count,
            turn_notifiers,
            request_config,
            last_run_interrupted,
            cancel_tx,
            cancel_rx,
            _state: _,
        } = self;
        let locked_prefix_len = history.len();
        Worker {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tools,
            hooks,
            system_prompt,
            history,
            locked_prefix_len,
            turn_count,
            turn_notifiers,
            request_config,
            last_run_interrupted,
            cancel_tx,
            cancel_rx,
            _state: PhantomData,
        }
    }
}
impl<C: LlmClient> Worker<C, CacheLocked> {
    /// Number of leading history messages that existed when the worker was locked.
    pub fn locked_prefix_len(&self) -> usize {
        self.locked_prefix_len
    }

    /// Converts back into a [`Mutable`] worker.
    ///
    /// All state is kept except the locked prefix length, which is reset to 0.
    pub fn unlock(self) -> Worker<C, Mutable> {
        let Worker {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tools,
            hooks,
            system_prompt,
            history,
            locked_prefix_len: _,
            turn_count,
            turn_notifiers,
            request_config,
            last_run_interrupted,
            cancel_tx,
            cancel_rx,
            _state: _,
        } = self;
        Worker {
            client,
            timeline,
            text_block_collector,
            tool_call_collector,
            tools,
            hooks,
            system_prompt,
            history,
            locked_prefix_len: 0,
            turn_count,
            turn_notifiers,
            request_config,
            last_run_interrupted,
            cancel_tx,
            cancel_rx,
            _state: PhantomData,
        }
    }
}
// Unit tests for the worker.
// NOTE(review): this module is currently empty; no worker behavior is covered here.
#[cfg(test)]
mod tests {
}