use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use async_trait::async_trait;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{broadcast, Notify};
use crate::backends::gemini::tools::{register_builtins, BuiltinDeps, FINISH_TOOL_NAME};
use crate::connections::{Connection, ConnectionStrategy, StepStream};
use crate::content::{Content, Part};
use crate::error::{Error, Result};
use crate::hooks::{HookRunner, SessionContext};
use crate::tools::ToolRunner;
use crate::types::{
CapabilitiesConfig, Step, StepStatus, SystemInstructions, ToolCall, ToolResult,
TranscriptEntry, TranscriptRole,
};
use super::gemma::{GemmaConfig, GemmaModel};
use super::generate::generate;
use super::tokenizer::{self, GemmaTokenizer};
use super::tool_parse::parse_tool_code;
use super::weights;
use super::LocalBackend;
/// Capacity of the per-connection `broadcast` channel that carries [`Step`]s
/// to subscribers.
const STEP_BROADCAST_CAPACITY: usize = 256;

/// Default location (inside the configured filesystem) of the model weights.
pub const WEIGHTS_PATH: &str = ".lh_local_model.safetensors";
/// Default location (inside the configured filesystem) of the tokenizer file.
pub const TOKENIZER_PATH: &str = ".lh_local_tokenizer.json";

/// Passed to `generate` for every model reply (presumably the new-token budget).
const MAX_NEW_TOKENS: usize = 256;
/// Upper bound on model invocations within a single turn of the tool loop.
const MAX_TOOL_ROUNDS: u32 = 5;
/// Configuration for the on-device (local) Gemma backend.
///
/// Start from [`LocalBackendConfig::new`] and chain the `with_*` builders; every
/// field is also public and may be assigned directly.
#[derive(Clone)]
pub struct LocalBackendConfig {
    /// Model name; forwarded to the built-in tool registry as `chat_model`.
    pub model: String,
    /// Optional system instructions, rendered and prepended to every prompt.
    pub system_instructions: Option<SystemInstructions>,
    /// Capability switches consulted when built-in tools are registered.
    pub capabilities: CapabilitiesConfig,
    /// Conversation id to reuse; a fresh UUID v4 is generated when `None`.
    pub conversation_id: Option<String>,
    /// Filesystem the weights and tokenizer are read from. Without one, no
    /// model is loaded and every turn reports that it is not downloaded.
    pub filesystem: Option<crate::filesystem::SharedFilesystem>,
    /// Path of the model weights inside `filesystem`.
    pub weights_path: String,
    /// Path of the tokenizer file inside `filesystem`.
    pub tokenizer_path: String,
}
impl LocalBackendConfig {
    /// Creates a config for `model` with default capabilities, no system
    /// instructions, no conversation id, no filesystem, and the default
    /// [`WEIGHTS_PATH`] / [`TOKENIZER_PATH`] locations.
    pub fn new(model: impl Into<String>) -> Self {
        let model: String = model.into();
        Self {
            model,
            weights_path: WEIGHTS_PATH.to_string(),
            tokenizer_path: TOKENIZER_PATH.to_string(),
            capabilities: CapabilitiesConfig::default(),
            system_instructions: None,
            conversation_id: None,
            filesystem: None,
        }
    }

    /// Sets the filesystem the weights and tokenizer are read from.
    pub fn with_filesystem(self, fs: crate::filesystem::SharedFilesystem) -> Self {
        Self {
            filesystem: Some(fs),
            ..self
        }
    }

    /// Replaces the model name.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Sets the system instructions prepended to every prompt.
    pub fn with_system_instructions(self, s: impl Into<SystemInstructions>) -> Self {
        Self {
            system_instructions: Some(s.into()),
            ..self
        }
    }

    /// Replaces the capability configuration.
    pub fn with_capabilities(self, c: CapabilitiesConfig) -> Self {
        Self {
            capabilities: c,
            ..self
        }
    }
}
/// A loaded Gemma model together with its tokenizer and the wgpu device it
/// runs on.
struct Engine {
    model: GemmaModel<LocalBackend>,
    tokenizer: GemmaTokenizer,
    device: burn::backend::wgpu::WgpuDevice,
}
impl Engine {
    /// Runs generation for `prompt` with [`MAX_NEW_TOKENS`] as the budget and
    /// returns the generated text.
    async fn run(&self, prompt: &str) -> String {
        let Self {
            model,
            tokenizer,
            device,
        } = self;
        generate(model, tokenizer, prompt, MAX_NEW_TOKENS, device).await
    }
}
/// Loads the Gemma weights and tokenizer named by `config` from its filesystem.
///
/// Returns `Ok(None)` when no filesystem is configured, or when either file
/// cannot be read or is empty (the model has not been downloaded yet).
/// The architecture is always the Gemma 3 270M config; `config.model` is not
/// consulted here.
///
/// # Errors
/// Returns an error when the files are present but the weights or the
/// tokenizer fail to load.
async fn load_engine(config: &LocalBackendConfig) -> Result<Option<Engine>> {
    let fs = match config.filesystem.as_ref() {
        Some(fs) => fs,
        None => return Ok(None),
    };
    // A read failure and an empty file are treated the same: nothing to load.
    let Some(weights) = fs
        .read(&config.weights_path)
        .await
        .ok()
        .filter(|bytes| !bytes.is_empty())
    else {
        return Ok(None);
    };
    let Some(tok_bytes) = fs
        .read(&config.tokenizer_path)
        .await
        .ok()
        .filter(|bytes| !bytes.is_empty())
    else {
        return Ok(None);
    };

    let device = burn::backend::wgpu::WgpuDevice::default();
    let model = weights::load_gemma(
        GemmaModel::<LocalBackend>::init(GemmaConfig::gemma_3_270m(), &device),
        &weights,
        &device,
    )
    .map_err(|e| Error::other(format!("local model load: {e}")))?;
    let tokenizer = tokenizer::load(&tok_bytes)
        .map_err(|e| Error::other(format!("local tokenizer load: {e}")))?;

    Ok(Some(Engine {
        model,
        tokenizer,
        device,
    }))
}
/// One entry of the conversation history held in `LoopState::history`.
///
/// Serialized with serde by `LocalConnection::history_bytes` and decoded by
/// `set_history_bytes` / `decode_transcript_bytes`, so the field and variant
/// names are part of a persisted format — change with care.
#[derive(Clone, Serialize, Deserialize)]
struct Turn {
    /// Who produced this text.
    role: TurnRole,
    /// The text itself. Tool outputs fed back to the model are stored as
    /// `User` turns wrapped in a `tool_output` fenced block.
    text: String,
}
/// Speaker of a [`Turn`]. Rendered as the `User:` / `Assistant:` prompt
/// prefixes and mapped to `TranscriptRole::User` / `TranscriptRole::Assistant`.
#[derive(Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
enum TurnRole {
    /// User input, or a tool output fed back to the model.
    User,
    /// A reply generated by the local model.
    Model,
}
/// State shared between a `LocalConnection` and the background task that
/// runs its turns.
struct LoopState {
    /// Conversation so far; rendered into every prompt.
    history: Mutex<Vec<Turn>>,
    /// `true` while no turn is running.
    idle: AtomicBool,
    /// Signalled with `notify_waiters` whenever `idle` is set back to `true`.
    idle_notify: Notify,
    /// Set by `cancel_turn`; checked at the start of every tool round.
    cancel: AtomicBool,
    /// Fan-out of emitted [`Step`]s to subscribers.
    steps: broadcast::Sender<Step>,
    /// Counter used to number emitted steps; reset by `clear_history`.
    next_step_index: AtomicU32,
    /// Pre-rendered system prompt, if any.
    system: Option<String>,
}
impl LoopState {
    /// Creates idle, uncancelled state with an empty history and step
    /// numbering starting at zero.
    fn new(steps: broadcast::Sender<Step>, system: Option<String>) -> Self {
        let history = Mutex::new(Vec::new());
        Self {
            history,
            idle: AtomicBool::new(true),
            idle_notify: Notify::new(),
            cancel: AtomicBool::new(false),
            next_step_index: AtomicU32::new(0),
            steps,
            system,
        }
    }

    /// Reserves and returns the next step index.
    fn alloc_step_index(&self) -> u32 {
        self.next_step_index.fetch_add(1, Ordering::Relaxed)
    }

    /// Broadcasts `step` to current subscribers.
    fn emit(&self, step: Step) {
        // `send` fails only when nobody is subscribed; dropping the step is fine.
        let _ = self.steps.send(step);
    }

    /// Builds the complete text prompt for the model.
    ///
    /// Layout: the system text (when non-empty) followed by a blank line; a
    /// tool-calling preamble listing every tool in `runner` (only when it has
    /// tools); then each history turn as a `User:` / `Assistant:` line; and
    /// finally an open `Assistant: ` cue for the model to continue.
    fn render_prompt_with_tools(&self, runner: Option<&ToolRunner>) -> String {
        let mut prompt = String::new();

        if let Some(system) = self.system.as_deref().filter(|s| !s.is_empty()) {
            prompt.push_str(system);
            prompt.push_str("\n\n");
        }

        let tools = runner
            .map(|r| r.iter_tools())
            .filter(|tools| !tools.is_empty());
        if let Some(tools) = tools {
            prompt.push_str(
                "You can call functions. To call one, output EXACTLY a fenced block:\n\
                 ```tool_code\nname(arg=value)\n```\n\
                 Available functions:\n",
            );
            for tool in &tools {
                prompt.push_str("- ");
                prompt.push_str(tool.name());
                prompt.push_str(": ");
                prompt.push_str(tool.description());
                prompt.push('\n');
            }
            prompt.push('\n');
        }

        for turn in self.history.lock().iter() {
            prompt.push_str(match turn.role {
                TurnRole::User => "User: ",
                TurnRole::Model => "Assistant: ",
            });
            prompt.push_str(&turn.text);
            prompt.push('\n');
        }
        prompt.push_str("Assistant: ");
        prompt
    }

    /// Emits an `Active` tool-call step for `tc`.
    fn emit_tool_call_step(&self, tc: &ToolCall) {
        let index = self.alloc_step_index();
        self.emit(Step::tool_call(index, tc.clone(), StepStatus::Active));
    }

    /// Emits a tool-result step carrying `result.error`, or an empty string
    /// when the call reported no error.
    // NOTE(review): the successful result value is not included in the step;
    // confirm that this is intended.
    fn emit_tool_result_step(&self, result: &ToolResult) {
        let index = self.alloc_step_index();
        let error_text = result.error.clone().unwrap_or_default();
        self.emit(Step::tool_result(index, error_text));
    }
}
use crate::backends::render_system;
/// Concatenates the `Part::Text` parts of `content` in order, with no
/// separator; every other part is ignored.
fn content_to_text(content: Content) -> String {
    content
        .parts
        .into_iter()
        .fold(String::new(), |mut text, part| {
            if let Part::Text(chunk) = part {
                text.push_str(&chunk);
            }
            text
        })
}
/// Runner bundle handed to local connections; an alias of the shared
/// `BackendRunners` type.
pub type LocalRunners = crate::backends::BackendRunners;
/// [`ConnectionStrategy`] that produces [`LocalConnection`]s backed by the
/// on-device Gemma model.
pub struct LocalConnectionStrategy {
    config: LocalBackendConfig,
    runners: LocalRunners,
    /// When set, `connect` stores the concrete connection here so callers can
    /// reach the `LocalConnection`-specific API.
    typed_capture: Option<Arc<parking_lot::Mutex<Option<Arc<LocalConnection>>>>>,
}
impl LocalConnectionStrategy {
    /// Creates a strategy with default runners and no capture slot.
    pub fn new(config: LocalBackendConfig) -> Self {
        let runners = LocalRunners::default();
        Self {
            config,
            runners,
            typed_capture: None,
        }
    }

    /// Replaces the runners given to every connection this strategy creates.
    pub fn with_runners(self, runners: LocalRunners) -> Self {
        Self { runners, ..self }
    }

    /// Installs a slot that is overwritten with the connection created by each
    /// `connect` call.
    pub fn with_typed_capture(
        self,
        slot: Arc<parking_lot::Mutex<Option<Arc<LocalConnection>>>>,
    ) -> Self {
        Self {
            typed_capture: Some(slot),
            ..self
        }
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl ConnectionStrategy for LocalConnectionStrategy {
    /// Builds a new [`LocalConnection`].
    ///
    /// Missing weights are not an error: the connection is still created, and
    /// each turn then reports that the model is not downloaded. When a tool
    /// runner is configured, the built-in tools are registered on it. The new
    /// connection is also published to the typed-capture slot, if one is set.
    ///
    /// # Errors
    /// Propagates failures from loading weights/tokenizer that are present but
    /// invalid.
    async fn connect(&self) -> Result<Arc<dyn Connection>> {
        let cfg = &self.config;
        let system = cfg.system_instructions.as_ref().map(render_system);
        let engine = load_engine(cfg).await?.map(Arc::new);

        if let Some(runner) = &self.runners.tool_runner {
            let deps = BuiltinDeps {
                fs: cfg.filesystem.clone(),
                chat_model: cfg.model.clone(),
                chat_client: None,
                image_model: String::new(),
                image_client: None,
            };
            let registered = register_builtins(runner, &cfg.capabilities, &deps);
            if !registered.is_empty() {
                tracing::debug!(?registered, "registered built-in tools (local)");
            }
        }

        let (steps_tx, _) = broadcast::channel::<Step>(STEP_BROADCAST_CAPACITY);
        let conversation_id: Arc<str> = cfg
            .conversation_id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
            .into();
        let connection = Arc::new(LocalConnection {
            state: Arc::new(LoopState::new(steps_tx, system)),
            engine,
            conversation_id,
            tool_runner: self.runners.tool_runner.clone(),
            hook_runner: self.runners.hook_runner.clone(),
            session_ctx: self.runners.session_ctx.clone(),
        });

        if let Some(slot) = self.typed_capture.as_ref() {
            *slot.lock() = Some(Arc::clone(&connection));
        }
        Ok(connection)
    }
}
/// A conversation with the local model, created by [`LocalConnectionStrategy`].
pub struct LocalConnection {
    /// State shared with the background task that runs each turn.
    state: Arc<LoopState>,
    /// Loaded model, or `None` when the weights/tokenizer were unavailable.
    engine: Option<Arc<Engine>>,
    conversation_id: Arc<str>,
    tool_runner: Option<Arc<ToolRunner>>,
    hook_runner: Option<Arc<HookRunner>>,
    /// Parent context; each turn derives a child context from it.
    session_ctx: Option<SessionContext>,
}
impl LocalConnection {
    /// Serializes the conversation history to JSON bytes, suitable for
    /// [`LocalConnection::set_history_bytes`] or [`decode_transcript_bytes`].
    ///
    /// # Errors
    /// Returns an error if serialization fails.
    pub fn history_bytes(&self) -> Result<Vec<u8>> {
        // Take a snapshot so the lock is released before serializing.
        let snapshot: Vec<Turn> = self.state.history.lock().to_vec();
        serde_json::to_vec(&snapshot).map_err(|e| Error::other(format!("history_bytes: {e}")))
    }

    /// Replaces the history with one decoded from `bytes`, as produced by
    /// `history_bytes`. Empty input is a no-op that leaves the history intact.
    ///
    /// # Errors
    /// Returns an error, leaving the history unchanged, if `bytes` cannot be
    /// decoded.
    pub fn set_history_bytes(&self, bytes: &[u8]) -> Result<()> {
        if !bytes.is_empty() {
            let restored = serde_json::from_slice::<Vec<Turn>>(bytes)
                .map_err(|e| Error::other(format!("set_history_bytes: {e}")))?;
            *self.state.history.lock() = restored;
        }
        Ok(())
    }

    /// Compaction is not supported by the local backend; always returns `false`.
    pub async fn compact(&self) -> bool {
        false
    }

    /// Empties the history and restarts step numbering at zero.
    pub fn clear_history(&self) {
        let mut history = self.state.history.lock();
        history.clear();
        drop(history);
        self.state.next_step_index.store(0, Ordering::Relaxed);
    }

    /// Returns the history as transcript entries. Tool calls are not tracked
    /// separately, so `tool_calls` is always empty.
    pub fn transcript(&self) -> Vec<TranscriptEntry> {
        let history = self.state.history.lock();
        let mut entries = Vec::with_capacity(history.len());
        for turn in history.iter() {
            let role = match turn.role {
                TurnRole::User => TranscriptRole::User,
                TurnRole::Model => TranscriptRole::Assistant,
            };
            entries.push(TranscriptEntry {
                role,
                text: turn.text.clone(),
                tool_calls: Vec::new(),
            });
        }
        entries
    }

    /// Whether the weights and tokenizer were loaded when connecting.
    pub fn is_model_loaded(&self) -> bool {
        self.engine.is_some()
    }
}
pub fn decode_transcript_bytes(bytes: &[u8]) -> Result<Vec<TranscriptEntry>> {
if bytes.is_empty() {
return Ok(Vec::new());
}
let history: Vec<Turn> = serde_json::from_slice(bytes)
.map_err(|e| Error::other(format!("decode_transcript_bytes: {e}")))?;
Ok(history
.into_iter()
.map(|t| TranscriptEntry {
role: match t.role {
TurnRole::User => TranscriptRole::User,
TurnRole::Model => TranscriptRole::Assistant,
},
text: t.text,
tool_calls: Vec::new(),
})
.collect())
}
/// Builds the `Done` turn-complete step for trajectory `traj`, carrying the
/// final `text`. The remaining `turn_complete` arguments are left empty/`None`.
fn terminal_step(state: &LoopState, traj: &str, text: String) -> Step {
    let index = state.alloc_step_index();
    Step::turn_complete(traj, index, StepStatus::Done, text, "", None, None)
}

/// Builds a turn-error step carrying `message`.
fn error_step(state: &LoopState, message: String) -> Step {
    let index = state.alloc_step_index();
    Step::turn_error(index, message)
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Connection for LocalConnection {
    /// `true` when no turn is running.
    fn is_idle(&self) -> bool {
        self.state.idle.load(Ordering::Acquire)
    }

    fn conversation_id(&self) -> &str {
        &self.conversation_id
    }

    /// Starts a turn for `content` and returns immediately. The turn runs on a
    /// spawned task and reports progress through the step stream.
    ///
    /// The pre-turn hook gate runs first. If it denies the turn, an error step
    /// is emitted and nothing else happens. Otherwise the text parts of
    /// `content` are appended to the history, and the task repeats a loop:
    ///
    /// - render the prompt and generate a reply;
    /// - if a tool runner is present and the reply contains a `tool_code` call,
    ///   dispatch the call and feed its output back as a `User` turn.
    ///
    /// The loop ends on a plain reply, on the finish tool, on cancellation, or
    /// after [`MAX_TOOL_ROUNDS`] model calls. It then emits a terminal step and
    /// runs the post-turn hooks.
    async fn send(&self, content: Content) -> Result<()> {
        let state = self.state.clone();
        let engine = self.engine.clone();
        let tool_runner = self.tool_runner.clone();
        let hook_runner = self.hook_runner.clone();
        let traj = uuid::Uuid::new_v4().to_string();
        let turn_ctx = self
            .session_ctx
            .as_ref()
            .map(|s| s.child())
            .unwrap_or_default();
        if let Some(denied) =
            crate::backends::dispatch::gate_pre_turn(hook_runner.as_ref(), &turn_ctx, &content)
                .await
        {
            state.emit(error_step(&state, denied));
            return Ok(());
        }
        let prompt = content_to_text(content);
        state.idle.store(false, Ordering::Release);
        state.cancel.store(false, Ordering::Release);
        state.history.lock().push(Turn {
            role: TurnRole::User,
            text: prompt,
        });
        crate::runtime::spawn(async move {
            let Some(engine) = engine else {
                state.emit(error_step(
                    &state,
                    "local model not downloaded — open the model tab and download Gemma first"
                        .to_string(),
                ));
                state.idle.store(true, Ordering::Release);
                state.idle_notify.notify_waiters();
                return;
            };
            let mut final_text = String::new();
            let mut rounds = 0u32;
            loop {
                rounds += 1;
                // NOTE(review): if the round limit is hit or the turn is
                // cancelled, `final_text` is still empty here, so the terminal
                // step carries no text.
                if rounds > MAX_TOOL_ROUNDS || state.cancel.load(Ordering::Acquire) {
                    break;
                }
                let rendered = state.render_prompt_with_tools(tool_runner.as_deref());
                let reply = engine.run(&rendered).await;
                // Tool calls are only recognised when there is a runner to execute them.
                let parsed = tool_runner
                    .as_ref()
                    .and_then(|_| parse_tool_code(&reply));
                // Every reply, whether or not it contains a call, becomes a model turn.
                state.history.lock().push(Turn {
                    role: TurnRole::Model,
                    text: reply.clone(),
                });
                let Some((name, args)) = parsed else {
                    final_text = reply;
                    break;
                };
                if name == FINISH_TOOL_NAME {
                    final_text = reply;
                    break;
                }
                let tool_call = ToolCall {
                    name,
                    args,
                    id: None,
                    canonical_path: None,
                };
                state.emit_tool_call_step(&tool_call);
                let post_result = crate::backends::dispatch::dispatch_tool_call(
                    tool_runner.as_ref(),
                    hook_runner.as_ref(),
                    &turn_ctx,
                    &tool_call,
                )
                .await;
                state.emit_tool_result_step(&post_result);
                let result_value = post_result.result.unwrap_or(Value::Null);
                let out = serde_json::to_string(&result_value).unwrap_or_default();
                state.history.lock().push(Turn {
                    role: TurnRole::User,
                    text: format!("```tool_output\n{out}\n```"),
                });
            }
            state.emit(terminal_step(&state, &traj, final_text.clone()));
            crate::backends::dispatch::dispatch_post_turn(
                hook_runner.as_ref(),
                &turn_ctx,
                &final_text,
            )
            .await;
            state.idle.store(true, Ordering::Release);
            state.idle_notify.notify_waiters();
        });
        Ok(())
    }

    /// Sends `content` as a plain text turn.
    async fn send_trigger(&self, content: String) -> Result<()> {
        self.send(Content::text(content)).await
    }

    /// No-op: tool calls are dispatched inside the `send` loop, so externally
    /// supplied results are ignored.
    async fn send_tool_results(&self, _results: Vec<ToolResult>) -> Result<()> {
        Ok(())
    }

    fn subscribe_steps(&self) -> StepStream {
        crate::backends::subscribe_step_stream(self.state.steps.subscribe(), "local")
    }

    /// Resolves once no turn is running.
    async fn wait_for_idle(&self) -> Result<()> {
        loop {
            // Create the `Notified` future *before* checking the flag. Tokio
            // guarantees that it observes any `notify_waiters()` issued after
            // its creation. Otherwise a turn finishing between the check and
            // the await would be missed, and this call would hang until some
            // later notification.
            let notified = self.state.idle_notify.notified();
            if self.is_idle() {
                return Ok(());
            }
            notified.await;
        }
    }

    /// Requests that the running turn stop at its next round boundary.
    fn cancel_turn(&self) {
        self.state.cancel.store(true, Ordering::Release);
    }

    /// Marks the connection idle and wakes `wait_for_idle` callers.
    ///
    /// It also asks any in-flight turn to stop at its next round boundary.
    /// Without that, the turn would keep generating and dispatching tools
    /// after the connection reported itself shut down.
    async fn shutdown(&self) -> Result<()> {
        self.state.cancel.store(true, Ordering::Release);
        self.state.idle.store(true, Ordering::Release);
        self.state.idle_notify.notify_waiters();
        Ok(())
    }
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::conversation::Conversation;
    use serde_json::json;

    /// Without weights, `connect` still succeeds, and the turn fails with the
    /// "not downloaded" message.
    #[tokio::test]
    async fn send_without_weights_errors_clearly() {
        let cfg = LocalBackendConfig::new("gemma-3-270m");
        let conn = LocalConnectionStrategy::new(cfg)
            .connect()
            .await
            .expect("connect succeeds even without weights");
        let conv = Conversation::new(conn);
        let resp = conv.chat("hi").await.expect("send dispatches");
        match resp.text().await {
            Ok(t) => panic!("expected a 'not downloaded' error, got: {t:?}"),
            Err(e) => assert!(
                e.to_string().contains("not downloaded"),
                "expected the not-downloaded message, got: {e}"
            ),
        }
    }

    /// `history_bytes` output decodes back into the same transcript.
    #[test]
    fn history_round_trips() {
        let (tx, _) = broadcast::channel::<Step>(STEP_BROADCAST_CAPACITY);
        let state = Arc::new(LoopState::new(tx, None));
        state.history.lock().push(Turn {
            role: TurnRole::User,
            text: "hello".into(),
        });
        state.history.lock().push(Turn {
            role: TurnRole::Model,
            text: "hi there".into(),
        });
        let conn = LocalConnection {
            state,
            engine: None,
            conversation_id: "test".into(),
            tool_runner: None,
            hook_runner: None,
            session_ctx: None,
        };
        let bytes = conn.history_bytes().unwrap();
        let entries = decode_transcript_bytes(&bytes).unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].role, TranscriptRole::User);
        assert_eq!(entries[1].text, "hi there");
    }

    /// A fenced `tool_code` block parses into a call that the runner executes.
    /// Note that this exercises `parse_tool_code` + `ToolRunner::execute`
    /// directly, not the `dispatch_tool_call` path used by `send`.
    #[tokio::test]
    async fn fenced_call_dispatches_through_runner() {
        use crate::tools::{ClosureTool, ToolRunner};
        let runner = ToolRunner::new();
        runner.register(ClosureTool::new(
            "echo",
            "echo back the message",
            json!({
                "type": "object",
                "properties": { "msg": { "type": "string" } }
            }),
            |args, _ctx| async move {
                let msg = args["msg"].as_str().unwrap_or("").to_string();
                Ok(json!({ "echoed": msg }))
            },
        ));
        let runner = Arc::new(runner);
        let reply = "ok\n```tool_code\necho(msg=\"hi there\")\n```";
        let (name, args) = parse_tool_code(reply).expect("a parsed call");
        assert_eq!(name, "echo");
        let out = runner.execute(&name, args).await.expect("execute");
        assert_eq!(out["echoed"], "hi there");
    }

    /// The tool preamble appears only when a runner is supplied; the history
    /// is rendered either way.
    #[test]
    fn prompt_lists_registered_tools() {
        use crate::tools::{ClosureTool, ToolRunner};
        let (tx, _) = broadcast::channel::<Step>(STEP_BROADCAST_CAPACITY);
        let state = LoopState::new(tx, None);
        state.history.lock().push(Turn {
            role: TurnRole::User,
            text: "read it".into(),
        });
        let runner = ToolRunner::new();
        runner.register(ClosureTool::new(
            "view_file",
            "read a file",
            json!({ "type": "object" }),
            |_a, _c| async move { Ok(json!({})) },
        ));
        let with = state.render_prompt_with_tools(Some(&runner));
        assert!(with.contains("```tool_code"));
        assert!(with.contains("view_file: read a file"));
        let without = state.render_prompt_with_tools(None);
        assert!(!without.contains("```tool_code"));
        assert!(without.contains("User: read it"));
    }

    /// Manual smoke test against real weights; run with `GEMMA_DIR` set and
    /// `--ignored`.
    #[tokio::test]
    #[ignore]
    async fn gemma_native_forward() {
        let dir = std::env::var("GEMMA_DIR")
            .expect("set GEMMA_DIR to a folder with model.safetensors + tokenizer.json");
        let weights = std::fs::read(format!("{dir}/model.safetensors")).expect("read weights");
        let tok_bytes = std::fs::read(format!("{dir}/tokenizer.json")).expect("read tokenizer.json");
        let device = burn::backend::wgpu::WgpuDevice::default();
        let model = GemmaModel::<LocalBackend>::init(GemmaConfig::gemma_3_270m(), &device);
        let model = weights::load_gemma(model, &weights, &device).expect("load_gemma");
        let tok = GemmaTokenizer::from_bytes(&tok_bytes).expect("load tokenizer");
        let prompt = "The capital of France is";
        let out = generate(&model, &tok, prompt, 16, &device).await;
        println!("\n=== GEMMA NATIVE FORWARD ===\nprompt: {prompt:?}\noutput: {out:?}\n============================\n");
        assert!(
            !out.trim().is_empty(),
            "empty continuation — immediate EOS or a loader bug"
        );
    }
}