use std::collections::HashSet;
use std::convert::Infallible;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::time::Duration;
use async_stream::stream;
use async_trait::async_trait;
use bytes::Bytes;
use futures_core::stream::Stream;
use http::header::{CACHE_CONTROL, CONTENT_TYPE, LOCATION, SET_COOKIE};
use http::{Request, Response, StatusCode};
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::client::conn::http1;
use ranvier::core::{Bus, Outcome, Transition};
use ranvier::http::{
BusHttpExt, CookieJar, GuardExec, GuardIntegration, GuardRejection, HttpResponse, IntoResponse,
RegisteredGuard, Sse, SseEvent, json_error_response,
};
use ranvier::runtime::Axon;
use serde::{Deserialize, Serialize};
use serde_json::json;
use soma_studio_core::{
ApiErrorResponse, AppInitResponse, ChatRetrievalSource, ChatSendRequest, ChatSendResponse,
ChatStreamEvent, ConversationCreateRequest, ConversationDeleteResponse, ConversationMessage,
ConversationSummary, HealthResponse, IngestJobSummary, IngestRescanRequest,
IngestStatusResponse, NotebookEmbeddingResponse, NotebookNoteCreateRequest,
NotebookNoteWriteRequest, NotebookRenderRequest, ProviderModelSummary,
ProviderSelectionResponse, ProviderSelectionUpdateRequest, ProviderSummary,
ProviderTestRequest, ProviderTestResponse, SearchFieldScope, SearchIndexStatusResponse,
SearchOpenActionResponse, SearchProfile, SearchResponse, SearchSort, SearchSourceType,
SourceRootCreateInput, SourceRootSummary, WorkspaceFileChangeApplyRequest,
WorkspaceFileChangeAuditEntry, WorkspaceFileChangeAuditStatus,
WorkspaceFileChangePreviewRequest, WorkspaceSourceRootCreateRequest, WorkspaceTaskRunRequest,
WorkspaceTaskRunSummary,
};
use tokio::time::interval;
use tokio::{net::TcpStream, time::timeout};
use tracing::warn;
use url::Url;
use uuid::Uuid;
use crate::app::AppContext;
use crate::chat_events::{
chat_event_matches_subscription, chat_stream_event_envelope, chat_stream_event_payload,
};
use crate::search_index::SearchIndexQueryOptions;
use crate::storage::NewConversationMessage;
use crate::workspace::WorkspaceError;
use crate::workspace_tasks::WorkspaceTaskError;
/// Boxed, pinned, `Send + Sync` stream of server-sent events whose items never
/// carry an error (`Infallible`).
type ChatStream = Pin<Box<dyn Stream<Item = Result<SseEvent, Infallible>> + Send + Sync>>;
/// Session cookie lifetime in seconds (60 * 60 * 24 * 30, i.e. 30 days); emitted as
/// the cookie's `Max-Age` attribute by the bootstrap redirect.
const SESSION_COOKIE_MAX_AGE_SECONDS: u64 = 60 * 60 * 24 * 30;
/// Newtype around an optional string.
/// NOTE(review): presumably the request's `Origin` header value — confirm where it is populated.
#[derive(Debug, Clone)]
pub struct RequestOrigin(pub Option<String>);
/// Newtype around an optional string.
/// NOTE(review): presumably the request's `Host` header value — confirm where it is populated.
#[derive(Debug, Clone)]
pub struct RequestHost(pub Option<String>);
/// Newtype around a string.
/// NOTE(review): presumably the request path — confirm where it is populated.
#[derive(Debug, Clone)]
pub struct RequestPath(pub String);
/// A redirect target paired with a cookie to set on the same response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedirectWithCookie {
    /// Value emitted in the `Location` header.
    pub location: String,
    /// Value emitted in the `Set-Cookie` header.
    pub set_cookie: String,
}

impl IntoResponse for RedirectWithCookie {
    /// Produces a `302 Found` response via `build_plain_response`, passing an empty
    /// string as its second argument and the `Location`, `Set-Cookie`, and
    /// `Cache-Control: no-store` header pairs.
    fn into_response(self) -> HttpResponse {
        let Self {
            location,
            set_cookie,
        } = self;
        let headers = vec![
            (LOCATION.as_str().to_string(), location),
            (SET_COOKIE.as_str().to_string(), set_cookie),
            (CACHE_CONTROL.as_str().to_string(), "no-store".to_string()),
        ];
        build_plain_response(StatusCode::FOUND, String::new(), headers)
    }
}
/// Plain description of an HTTP response: status code, content type, and body bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotebookWireResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Value emitted in the `Content-Type` header.
    pub content_type: String,
    /// Raw response payload.
    pub body: Vec<u8>,
}

impl IntoResponse for NotebookWireResponse {
    /// Converts into an HTTP response with `Cache-Control: no-store`.
    ///
    /// A `status` rejected by `StatusCode::from_u16` is replaced with
    /// `500 Internal Server Error`.
    ///
    /// # Panics
    /// Panics if the response builder reports an error when the body is attached.
    fn into_response(self) -> HttpResponse {
        let Self {
            status,
            content_type,
            body,
        } = self;
        let status_code =
            StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        Response::builder()
            .status(status_code)
            .header(CONTENT_TYPE, content_type)
            .header(CACHE_CONTROL, "no-store")
            .body(
                Full::new(Bytes::from(body))
                    .map_err(|never| match never {})
                    .boxed(),
            )
            .expect("notebook response should be buildable")
    }
}
/// Pre-serialized JSON payload together with the HTTP status to send it with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonWireResponse {
    /// Numeric HTTP status code.
    pub status: u16,
    /// Body bytes, sent with `Content-Type: application/json`.
    pub body: Vec<u8>,
}

impl IntoResponse for JsonWireResponse {
    /// Converts into an `application/json` response with `Cache-Control: no-store`.
    ///
    /// A `status` rejected by `StatusCode::from_u16` is replaced with
    /// `500 Internal Server Error`.
    ///
    /// # Panics
    /// Panics if the response builder reports an error when the body is attached.
    fn into_response(self) -> HttpResponse {
        let Self { status, body } = self;
        let status_code =
            StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        Response::builder()
            .status(status_code)
            .header(CONTENT_TYPE, "application/json")
            .header(CACHE_CONTROL, "no-store")
            .body(
                Full::new(Bytes::from(body))
                    .map_err(|never| match never {})
                    .boxed(),
            )
            .expect("json response should be buildable")
    }
}
impl From<NotebookArtifactFile> for NotebookWireResponse {
fn from(value: NotebookArtifactFile) -> Self {
Self {
status: StatusCode::OK.as_u16(),
content_type: value.content_type,
body: value.bytes,
}
}
}
/// Raw artifact bytes together with the content type to serve them as.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NotebookArtifactFile {
    /// File contents.
    pub bytes: Vec<u8>,
    /// Value emitted in the `Content-Type` header.
    pub content_type: String,
}

impl IntoResponse for NotebookArtifactFile {
    /// Converts into a `200 OK` response with `Cache-Control: no-store`.
    ///
    /// # Panics
    /// Panics if the response builder reports an error when the body is attached.
    fn into_response(self) -> HttpResponse {
        let Self {
            bytes,
            content_type,
        } = self;
        Response::builder()
            .status(StatusCode::OK)
            .header(CONTENT_TYPE, content_type)
            .header(CACHE_CONTROL, "no-store")
            .body(
                Full::new(Bytes::from(bytes))
                    .map_err(|never| match never {})
                    .boxed(),
            )
            .expect("artifact response should be buildable")
    }
}
/// Unit marker type; its `Transition<(), RedirectWithCookie>` implementation consumes
/// a bootstrap token and answers with a redirect that sets the session cookie.
#[derive(Clone, Copy)]
struct BootstrapRedirect;
/// Guard that rejects requests for which `require_session` fails.
#[derive(Clone, Copy)]
pub struct RequireSessionGuard;

/// Executor registered by [`RequireSessionGuard`].
struct RequireSessionGuardExec;

#[async_trait]
impl GuardExec for RequireSessionGuardExec {
    /// Looks up the `AppContext` on the bus and runs `require_session`.
    ///
    /// A missing context yields a `500` rejection carrying the lookup error text;
    /// a failed session check yields an "unauthorized" rejection.
    async fn exec_guard(&self, bus: &mut Bus) -> Result<(), GuardRejection> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Ok(ctx) => ctx,
            Err(cause) => {
                return Err(GuardRejection::new(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    cause.to_string(),
                ));
            }
        };
        match require_session(bus, &ctx).await {
            Ok(_) => Ok(()),
            Err(denied) => Err(GuardRejection::unauthorized(denied)),
        }
    }
}

impl GuardIntegration for RequireSessionGuard {
    /// Registers the guard with only an executor: no bus injectors, no response
    /// hooks, and no preflight handling.
    fn register(self) -> RegisteredGuard {
        RegisteredGuard {
            bus_injectors: Vec::new(),
            response_extractor: None,
            response_body_transform: None,
            exec: std::sync::Arc::new(RequireSessionGuardExec),
            handles_preflight: false,
            preflight_config: None,
        }
    }
}
/// Guard that rejects requests for which `validate_same_origin` fails.
#[derive(Clone, Copy)]
pub struct RequireSameOriginGuard;

/// Executor registered by [`RequireSameOriginGuard`].
struct RequireSameOriginGuardExec;

#[async_trait]
impl GuardExec for RequireSameOriginGuardExec {
    /// Runs `validate_same_origin` against the bus; a failure becomes a
    /// "forbidden" rejection carrying the validation error.
    async fn exec_guard(&self, bus: &mut Bus) -> Result<(), GuardRejection> {
        match validate_same_origin(bus) {
            Ok(()) => Ok(()),
            Err(reason) => Err(GuardRejection::forbidden(reason)),
        }
    }
}

impl GuardIntegration for RequireSameOriginGuard {
    /// Registers the guard with only an executor: no bus injectors, no response
    /// hooks, and no preflight handling.
    fn register(self) -> RegisteredGuard {
        RegisteredGuard {
            bus_injectors: Vec::new(),
            response_extractor: None,
            response_body_transform: None,
            exec: std::sync::Arc::new(RequireSameOriginGuardExec),
            handles_preflight: false,
            preflight_config: None,
        }
    }
}
#[async_trait]
impl Transition<(), RedirectWithCookie> for BootstrapRedirect {
    type Error = String;
    type Resources = ();

    /// Exchanges a one-time bootstrap token for a session cookie and redirects to `/`.
    ///
    /// Steps, in order:
    /// 1. read the `token` query parameter and consume it through the app context;
    /// 2. if the request already carries a `soma_studio_session` cookie that storage
    ///    knows about, re-register that session; otherwise issue a new one;
    /// 3. answer with a redirect whose cookie is `HttpOnly`, `SameSite=Strict`, and
    ///    lives for `SESSION_COOKIE_MAX_AGE_SECONDS`.
    ///
    /// Faults when the app context or token is missing, the token is rejected, the
    /// storage lookup fails, the persisted id fails validation, or a new session
    /// cannot be issued.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<RedirectWithCookie, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        let token = match bus.query_param::<String>("token") {
            Some(token) => token,
            None => {
                return Outcome::Fault("missing bootstrap token query parameter".to_string());
            }
        };
        if !ctx.consume_bootstrap_token(&token) {
            return Outcome::Fault("invalid bootstrap token".to_string());
        }

        let cookie_session_id = bus
            .get_cloned::<CookieJar>()
            .ok()
            .and_then(|jar| jar.get("soma_studio_session").map(str::to_owned));

        // Try to keep the caller's existing session; `None` means a fresh one is needed.
        let mut restored = None;
        if let Some(session_id) = cookie_session_id {
            match ctx.storage.session_exists(&session_id).await {
                Ok(true) => {
                    let Some(session) = ctx.remember_session(session_id) else {
                        return Outcome::Fault(
                            "persisted session id failed validation during bootstrap".to_string(),
                        );
                    };
                    restored = Some(session);
                }
                Ok(false) => {}
                Err(cause) => {
                    return Outcome::Fault(format!(
                        "failed to restore persisted session during bootstrap: {cause}"
                    ));
                }
            }
        }

        let session = match restored {
            Some(session) => session,
            None => match ctx.issue_session().await {
                Ok(session) => session,
                Err(cause) => {
                    return Outcome::Fault(format!("failed to issue session: {cause}"));
                }
            },
        };

        let set_cookie = format!(
            "soma_studio_session={}; Path=/; HttpOnly; SameSite=Strict; Max-Age={}",
            session.id, SESSION_COOKIE_MAX_AGE_SECONDS
        );
        Outcome::Next(RedirectWithCookie {
            location: "/".to_string(),
            set_cookie,
        })
    }
}
/// Health-check transition: reports the configured app name and bind address.
#[derive(Clone, Copy)]
struct Health;

#[async_trait]
impl Transition<(), HealthResponse> for Health {
    type Error = String;
    type Resources = ();

    /// Returns `ok: true` plus `app_name` and `bind_addr` from the app config, or
    /// faults when no `AppContext` is available on the bus.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<HealthResponse, Self::Error> {
        match bus.get_cloned::<AppContext>() {
            Err(cause) => Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => Outcome::Next(HealthResponse {
                ok: true,
                app_name: ctx.config.app_name.clone(),
                bind_addr: ctx.config.bind_addr.clone(),
            }),
        }
    }
}
/// Transition that assembles the initial application state for the client.
#[derive(Clone, Copy)]
struct Init;

#[async_trait]
impl Transition<(), AppInitResponse> for Init {
    type Error = String;
    type Resources = ();

    /// Builds an `AppInitResponse` from the current session (if any), the merged
    /// provider summaries, the stored provider selection, and a fixed route list.
    ///
    /// No session is required: `authenticated` reflects whether
    /// `session_id_from_bus` produced an id. Faults when the app context is missing
    /// or when provider statuses / selection cannot be loaded from storage.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<AppInitResponse, Self::Error> {
        /// Route patterns advertised in the response, in the order they are emitted.
        const ROUTES: &[&str] = &[
            "/api/health",
            "/api/providers",
            "/api/providers/models",
            "/api/providers/selection",
            "/api/source-roots",
            "/api/workspace/files",
            "/api/workspace/file-preview",
            "/api/workspace/file-change-preview",
            "/api/workspace/file-change-apply",
            "/api/workspace/file-change-audits",
            "/api/workspace/source-root",
            "/api/workspace/task",
            "/api/workspace/task-runs",
            "/api/workspace/task-runs/:id",
            "/api/workspace/task-runs/:id/cancel",
            "/api/ingest/jobs",
            "/api/ingest/status",
            "/api/conversations",
            "/api/conversations/:id/messages",
            "/api/notebook/tree",
            "/api/notebook/note",
            "/api/notebook/search",
            "/api/notebook/adapters",
            "/api/notebook/render",
            "/api/notebook/index",
            "/api/notebook/chunks",
            "/api/notebook/embeddings",
            "/api/notebook/retrieve",
            "/api/chat/send",
            "/api/stream/chat",
        ];

        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        let session_id = session_id_from_bus(bus, &ctx).await;
        let statuses = match ctx.storage.list_provider_statuses().await {
            Err(cause) => {
                return Outcome::Fault(format!("failed to load provider statuses: {cause}"));
            }
            Ok(statuses) => statuses,
        };
        let selection = match ctx.storage.load_provider_selection().await {
            Err(cause) => {
                return Outcome::Fault(format!("failed to load provider selection: {cause}"));
            }
            Ok(selection) => selection,
        };

        let authenticated = session_id.is_some();
        let providers = merge_provider_summaries(ctx.provider_summaries(), statuses);
        Outcome::Next(AppInitResponse {
            app_name: ctx.config.app_name.clone(),
            authenticated,
            session_id,
            providers,
            selected_provider: selection.selected_provider,
            selected_model_id: selection.selected_model_id,
            routes: ROUTES.iter().map(|route| (*route).to_string()).collect(),
        })
    }
}
/// Transition that lists provider summaries merged with their stored statuses.
#[derive(Clone, Copy)]
struct Providers;

#[async_trait]
impl Transition<(), Vec<ProviderSummary>> for Providers {
    type Error = String;
    type Resources = ();

    /// Loads provider statuses from storage and merges them into the app's provider
    /// summaries. Faults when the app context is missing or the status query fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<ProviderSummary>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match ctx.storage.list_provider_statuses().await {
            Err(cause) => Outcome::Fault(format!("failed to load provider statuses: {cause}")),
            Ok(statuses) => {
                Outcome::Next(merge_provider_summaries(ctx.provider_summaries(), statuses))
            }
        }
    }
}
/// Transition that lists the models offered by one provider.
#[derive(Clone, Copy)]
struct ProviderModels;

#[async_trait]
impl Transition<(), serde_json::Value> for ProviderModels {
    type Error = String;
    type Resources = ();

    /// Requires a session, reads the `provider` query parameter, normalizes it, and
    /// fetches the model list for `ollama` or `lmstudio`, returned as JSON.
    ///
    /// Faults when the app context or session is missing, the parameter is absent,
    /// the provider is not one of the two supported ids, or the fetch fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<serde_json::Value, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let Some(raw_provider) = bus.query_param::<String>("provider") else {
            return Outcome::Fault("missing provider query parameter".to_string());
        };
        let provider = normalize_provider_id(&raw_provider);
        let listing = match provider.as_str() {
            "ollama" => fetch_ollama_models().await,
            "lmstudio" => fetch_lmstudio_models().await,
            other => Err(format!("unsupported provider: {other}")),
        };
        match listing {
            Err(reason) => Outcome::Fault(reason),
            Ok(models) => Outcome::Next(json!(models)),
        }
    }
}
/// Transition that returns the currently registered source roots.
#[derive(Clone, Copy)]
struct ListSourceRoots;

#[async_trait]
impl Transition<(), Vec<SourceRootSummary>> for ListSourceRoots {
    type Error = String;
    type Resources = ();

    /// Requires a session, then returns a clone of the in-memory source root list.
    ///
    /// # Panics
    /// Panics if the source root lock is poisoned.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<SourceRootSummary>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        // The read guard is a temporary of this statement, so the lock is released
        // as soon as the clone has been taken.
        let snapshot = ctx
            .source_roots
            .read()
            .expect("source root store poisoned")
            .clone();
        Outcome::Next(snapshot)
    }
}
#[derive(Clone, Copy)]
struct CreateSourceRoot;
#[async_trait]
impl Transition<SourceRootCreateInput, NotebookWireResponse> for CreateSourceRoot {
type Error = String;
type Resources = ();
async fn run(
&self,
input: SourceRootCreateInput,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let path = match std::fs::canonicalize(PathBuf::from(&input.path)) {
Ok(path) => path,
Err(error) => {
return Outcome::Next(notebook_error_response(&format!(
"failed to resolve source root '{}': {error}",
input.path
)));
}
};
if !path.is_dir() {
return Outcome::Next(notebook_error_response(&format!(
"source root must be a directory: {}",
path.display()
)));
}
let summary = SourceRootSummary {
id: Uuid::new_v4(),
path: path.to_string_lossy().to_string(),
read_only: true,
};
let persisted = match app.storage.upsert_source_root(&summary).await {
Ok(summary) => summary,
Err(error) => return Outcome::Fault(format!("failed to persist source root: {error}")),
};
Outcome::Next(notebook_json_response(
StatusCode::OK,
&app.register_source_root(persisted),
))
}
}
/// Transition that rescans one or all source roots and records the run as an ingest job.
#[derive(Clone, Copy)]
struct RescanIngest;

#[async_trait]
impl Transition<IngestRescanRequest, NotebookWireResponse> for RescanIngest {
    type Error = String;
    type Resources = ();

    /// Runs an ingest rescan.
    ///
    /// Flow:
    /// 1. require a session;
    /// 2. pick the roots to scan — the one matching `input.source_root_id`, or all
    ///    registered roots when no id is given;
    /// 3. create an ingest job (scope label is the id, or `"all"`);
    /// 4. scan every selected root on a blocking worker;
    /// 5. per root, materialize text artifacts and replace the stored source files;
    /// 6. mark the job complete with the total file count.
    ///
    /// Failures in steps 4–5 are recorded on the job on a best-effort basis (the
    /// result of that bookkeeping call is ignored) and returned as a notebook error
    /// response. A missing context/session or a failure to create or complete the
    /// job is a fault.
    ///
    /// # Panics
    /// Panics if the source root lock is poisoned.
    async fn run(
        &self,
        input: IngestRescanRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }

        let mut targets = ctx
            .source_roots
            .read()
            .expect("source root store poisoned")
            .clone();
        if let Some(wanted) = input.source_root_id.as_deref() {
            targets.retain(|root| root.id.to_string() == wanted);
        }
        if targets.is_empty() {
            return Outcome::Next(notebook_error_response(
                "source root를 먼저 등록해야 ingest rescan을 시작할 수 있습니다.",
            ));
        }

        let scope_label = match input.source_root_id.as_ref() {
            Some(id) => id.clone(),
            None => "all".to_string(),
        };
        let job = match ctx
            .storage
            .create_ingest_job(input.source_root_id.as_deref(), &scope_label)
            .await
        {
            Err(cause) => return Outcome::Fault(format!("failed to create ingest job: {cause}")),
            Ok(job) => job,
        };

        // Directory scanning is blocking work, so it runs off the async executor.
        let joined = tokio::task::spawn_blocking(move || {
            let mut per_root = Vec::new();
            for root in &targets {
                let files = crate::ingest::scan_source_root(root)?;
                per_root.push((root.id.to_string(), files));
            }
            Ok::<_, String>(per_root)
        })
        .await;
        // A worker join failure and a scan failure are handled identically below,
        // so fold both into one `Result<_, String>`.
        let scan_result = match joined {
            Ok(inner) => inner,
            Err(join_error) => Err(format!("ingest worker join failed: {join_error}")),
        };
        let scanned = match scan_result {
            Ok(scanned) => scanned,
            Err(detail) => {
                let _ = ctx
                    .storage
                    .complete_ingest_job(&job.id, "error", 0, Some(&detail))
                    .await;
                return Outcome::Next(notebook_error_response(&detail));
            }
        };

        let mut total_files = 0usize;
        for (root_id, files) in &scanned {
            // `failure` carries the message for whichever step failed for this root.
            // The count only includes this root once materialization has succeeded.
            let failure = match crate::ingest::materialize_text_artifacts(&ctx.config, files) {
                Err(cause) => Some(format!(
                    "failed to materialize source root text artifacts: {cause}"
                )),
                Ok(materialized) => {
                    total_files += materialized.len();
                    match ctx
                        .storage
                        .replace_source_files(root_id, &materialized)
                        .await
                    {
                        Ok(_) => None,
                        Err(cause) => Some(format!("failed to persist source files: {cause}")),
                    }
                }
            };
            if let Some(detail) = failure {
                let _ = ctx
                    .storage
                    .complete_ingest_job(&job.id, "error", total_files, Some(&detail))
                    .await;
                return Outcome::Next(notebook_error_response(&detail));
            }
        }

        match ctx
            .storage
            .complete_ingest_job(&job.id, "complete", total_files, None)
            .await
        {
            Err(cause) => Outcome::Fault(format!("failed to complete ingest job: {cause}")),
            Ok(finished) => Outcome::Next(notebook_json_response(StatusCode::OK, &finished)),
        }
    }
}
/// Transition that lists recorded ingest jobs.
#[derive(Clone, Copy)]
struct ListIngestJobs;

#[async_trait]
impl Transition<(), Vec<IngestJobSummary>> for ListIngestJobs {
    type Error = String;
    type Resources = ();

    /// Requires a session and returns the ingest jobs from storage. Faults when the
    /// app context or session is missing or the storage query fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<IngestJobSummary>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let listing = ctx.storage.list_ingest_jobs().await;
        match listing {
            Err(cause) => Outcome::Fault(format!("failed to list ingest jobs: {cause}")),
            Ok(jobs) => Outcome::Next(jobs),
        }
    }
}
/// Transition that reports the current ingest status.
#[derive(Clone, Copy)]
struct GetIngestStatus;

#[async_trait]
impl Transition<(), IngestStatusResponse> for GetIngestStatus {
    type Error = String;
    type Resources = ();

    /// Requires a session and returns the ingest status loaded from storage. Faults
    /// when the app context or session is missing or the storage query fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<IngestStatusResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let loaded = ctx.storage.load_ingest_status().await;
        match loaded {
            Err(cause) => Outcome::Fault(format!("failed to load ingest status: {cause}")),
            Ok(status) => Outcome::Next(status),
        }
    }
}
/// Transition that lists the conversations belonging to the caller's session.
#[derive(Clone, Copy)]
struct ListConversations;

#[async_trait]
impl Transition<(), Vec<ConversationSummary>> for ListConversations {
    type Error = String;
    type Resources = ();

    /// Resolves the session id via `require_session` and returns that session's
    /// conversations. Faults when the app context or session is missing or the
    /// storage query fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<ConversationSummary>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        let session_id = match require_session(bus, &ctx).await {
            Err(denied) => return Outcome::Fault(denied),
            Ok(session_id) => session_id,
        };
        let listing = ctx.storage.list_conversations(&session_id).await;
        match listing {
            Err(cause) => Outcome::Fault(format!("failed to list conversations: {cause}")),
            Ok(conversations) => Outcome::Next(conversations),
        }
    }
}
/// Transition that creates a conversation for the caller's session.
#[derive(Clone, Copy)]
struct CreateConversation;

#[async_trait]
impl Transition<ConversationCreateRequest, ConversationSummary> for CreateConversation {
    type Error = String;
    type Resources = ();

    /// Resolves the session id via `require_session` and creates a conversation
    /// with the optional `input.title`. Faults when the app context or session is
    /// missing or storage rejects the insert.
    async fn run(
        &self,
        input: ConversationCreateRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<ConversationSummary, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        let session_id = match require_session(bus, &ctx).await {
            Err(denied) => return Outcome::Fault(denied),
            Ok(session_id) => session_id,
        };
        let created = ctx
            .storage
            .create_conversation(&session_id, input.title.as_deref())
            .await;
        match created {
            Err(cause) => Outcome::Fault(format!("failed to create conversation: {cause}")),
            Ok(conversation) => Outcome::Next(conversation),
        }
    }
}
/// Transition that lists the messages of one conversation.
#[derive(Clone, Copy)]
struct ListConversationMessages;

#[async_trait]
impl Transition<(), Vec<ConversationMessage>> for ListConversationMessages {
    type Error = String;
    type Resources = ();

    /// Resolves the session id, reads the conversation id through
    /// `conversation_id_from_bus`, and returns the messages storage holds for that
    /// session/conversation pair. Faults when any of those steps fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<ConversationMessage>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        let session_id = match require_session(bus, &ctx).await {
            Err(denied) => return Outcome::Fault(denied),
            Ok(session_id) => session_id,
        };
        let conversation_id = match conversation_id_from_bus(bus) {
            Err(reason) => return Outcome::Fault(reason),
            Ok(conversation_id) => conversation_id,
        };
        let listing = ctx
            .storage
            .list_conversation_messages(&session_id, &conversation_id)
            .await;
        match listing {
            Err(cause) => {
                Outcome::Fault(format!("failed to list conversation messages: {cause}"))
            }
            Ok(messages) => Outcome::Next(messages),
        }
    }
}
/// Transition that deletes one conversation of the caller's session.
#[derive(Clone, Copy)]
struct DeleteConversation;

#[async_trait]
impl Transition<(), ConversationDeleteResponse> for DeleteConversation {
    type Error = String;
    type Resources = ();

    /// Resolves the session id, reads the conversation id through
    /// `conversation_id_from_bus`, and asks storage to delete that conversation.
    /// The response echoes the id together with the `deleted` value storage
    /// returned. Faults when any of those steps fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<ConversationDeleteResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        let session_id = match require_session(bus, &ctx).await {
            Err(denied) => return Outcome::Fault(denied),
            Ok(session_id) => session_id,
        };
        let conversation_id = match conversation_id_from_bus(bus) {
            Err(reason) => return Outcome::Fault(reason),
            Ok(conversation_id) => conversation_id,
        };
        let removal = ctx
            .storage
            .delete_conversation(&session_id, &conversation_id)
            .await;
        match removal {
            Err(cause) => Outcome::Fault(format!("failed to delete conversation: {cause}")),
            Ok(deleted) => Outcome::Next(ConversationDeleteResponse {
                conversation_id,
                deleted,
            }),
        }
    }
}
/// Transition that lists notebook notes.
#[derive(Clone, Copy)]
struct ListNotebookNotes;

#[async_trait]
impl Transition<(), NotebookWireResponse> for ListNotebookNotes {
    type Error = String;
    type Resources = ();

    /// Requires a session and returns the result of `crate::notebook::list_notes`
    /// as a `200 OK` JSON response; a listing failure is returned as a notebook
    /// error response. Faults only when the app context or session is missing.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Ok(app) => app,
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
        };
        if let Err(error) = require_session(bus, &app).await {
            return Outcome::Fault(error);
        }
        // The success arm passes a borrow of the listing to the JSON helper.
        match crate::notebook::list_notes(&app.config) {
            Ok(listing) => Outcome::Next(notebook_json_response(StatusCode::OK, &listing)),
            Err(error) => Outcome::Next(notebook_error_response(&error)),
        }
    }
}
/// Transition that reads a single notebook note.
#[derive(Clone, Copy)]
struct ReadNotebookNote;

#[async_trait]
impl Transition<(), NotebookWireResponse> for ReadNotebookNote {
    type Error = String;
    type Resources = ();

    /// Requires a session, reads the `path` query parameter through
    /// `required_decoded_query_param`, and returns the note from
    /// `crate::notebook::read_note` as a `200 OK` JSON response.
    ///
    /// A missing parameter or a read failure is returned as a notebook error
    /// response. Faults only when the app context or session is missing.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Ok(app) => app,
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
        };
        if let Err(error) = require_session(bus, &app).await {
            return Outcome::Fault(error);
        }
        let path = match required_decoded_query_param(bus, "path", "missing notebook note path") {
            Ok(path) => path,
            Err(error) => return Outcome::Next(notebook_error_response(&error)),
        };
        // The success arm passes a borrow of the loaded note to the JSON helper.
        match crate::notebook::read_note(&app.config, &path) {
            Ok(loaded) => Outcome::Next(notebook_json_response(StatusCode::OK, &loaded)),
            Err(error) => Outcome::Next(notebook_error_response(&error)),
        }
    }
}
#[derive(Clone, Copy)]
struct ReadNotebookArtifact;
#[async_trait]
impl Transition<(), NotebookWireResponse> for ReadNotebookArtifact {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
if let Err(error) = require_session(bus, &app).await {
return Outcome::Fault(error);
}
let path = match required_decoded_query_param(bus, "path", "missing notebook artifact path")
{
Ok(path) => path,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
match crate::notebook::read_artifact(&app.config, &path) {
Ok(bytes) => Outcome::Next(
NotebookArtifactFile {
bytes,
content_type: "application/pdf".to_string(),
}
.into(),
),
Err(error) => Outcome::Next(notebook_error_response(&error)),
}
}
}
/// Transition that creates a notebook note.
#[derive(Clone, Copy)]
struct CreateNotebookNote;

#[async_trait]
impl Transition<NotebookNoteCreateRequest, NotebookWireResponse> for CreateNotebookNote {
    type Error = String;
    type Resources = ();

    /// Requires a session and passes `input` to `crate::notebook::create_note`;
    /// the created note is returned as a `200 OK` JSON response, and a creation
    /// failure as a notebook error response. Faults only when the app context or
    /// session is missing.
    async fn run(
        &self,
        input: NotebookNoteCreateRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Ok(app) => app,
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
        };
        if let Err(error) = require_session(bus, &app).await {
            return Outcome::Fault(error);
        }
        // The success arm passes a borrow of the created note to the JSON helper.
        match crate::notebook::create_note(&app.config, input) {
            Ok(created) => Outcome::Next(notebook_json_response(StatusCode::OK, &created)),
            Err(error) => Outcome::Next(notebook_error_response(&error)),
        }
    }
}
/// Transition that writes (updates) a notebook note.
#[derive(Clone, Copy)]
struct WriteNotebookNote;

#[async_trait]
impl Transition<NotebookNoteWriteRequest, NotebookWireResponse> for WriteNotebookNote {
    type Error = String;
    type Resources = ();

    /// Requires a session and passes `input` to `crate::notebook::write_note`;
    /// the written note is returned as a `200 OK` JSON response, and a write
    /// failure as a notebook error response. Faults only when the app context or
    /// session is missing.
    async fn run(
        &self,
        input: NotebookNoteWriteRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Ok(app) => app,
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
        };
        if let Err(error) = require_session(bus, &app).await {
            return Outcome::Fault(error);
        }
        // The success arm passes a borrow of the saved note to the JSON helper.
        match crate::notebook::write_note(&app.config, input) {
            Ok(saved) => Outcome::Next(notebook_json_response(StatusCode::OK, &saved)),
            Err(error) => Outcome::Next(notebook_error_response(&error)),
        }
    }
}
/// Transition that searches notebook notes.
#[derive(Clone, Copy)]
struct SearchNotebookNotes;

#[async_trait]
impl Transition<(), NotebookWireResponse> for SearchNotebookNotes {
    type Error = String;
    type Resources = ();

    /// Requires a session and runs `crate::notebook::search_notes` with the `query`
    /// query parameter (empty string when absent). Results are returned as a
    /// `200 OK` JSON response, a search failure as a notebook error response.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let query = match bus.query_param::<String>("query") {
            Some(text) => text,
            None => String::new(),
        };
        let reply = match crate::notebook::search_notes(&ctx.config, &query) {
            Ok(results) => notebook_json_response(StatusCode::OK, &results),
            Err(reason) => notebook_error_response(&reason),
        };
        Outcome::Next(reply)
    }
}
/// Transition that reports notebook adapter statuses.
#[derive(Clone, Copy)]
struct ListNotebookAdapters;

#[async_trait]
impl Transition<(), NotebookWireResponse> for ListNotebookAdapters {
    type Error = String;
    type Resources = ();

    /// Requires a session and returns `crate::notebook::adapter_statuses()` as a
    /// `200 OK` JSON response. Faults when the app context or session is missing.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let statuses = crate::notebook::adapter_statuses();
        Outcome::Next(notebook_json_response(StatusCode::OK, &statuses))
    }
}
/// Transition that renders a notebook note.
#[derive(Clone, Copy)]
struct RenderNotebookNote;

#[async_trait]
impl Transition<NotebookRenderRequest, NotebookWireResponse> for RenderNotebookNote {
    type Error = String;
    type Resources = ();

    /// Requires a session and runs `crate::notebook::render_note` on a blocking
    /// worker with a clone of the app config.
    ///
    /// The render result is returned as a `200 OK` JSON response; a render failure
    /// or a worker join failure is returned as a notebook error response. Faults
    /// only when the app context or session is missing.
    async fn run(
        &self,
        input: NotebookRenderRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let config = ctx.config.clone();
        let worker =
            tokio::task::spawn_blocking(move || crate::notebook::render_note(&config, input));
        let reply = match worker.await {
            Err(join_error) => notebook_error_response(&format!(
                "notebook render worker join failed: {join_error}"
            )),
            Ok(Ok(response)) => notebook_json_response(StatusCode::OK, &response),
            Ok(Err(reason)) => notebook_error_response(&reason),
        };
        Outcome::Next(reply)
    }
}
/// Transition that reports the notebook index status.
#[derive(Clone, Copy)]
struct GetNotebookIndex;

#[async_trait]
impl Transition<(), NotebookWireResponse> for GetNotebookIndex {
    type Error = String;
    type Resources = ();

    /// Requires a session and returns `crate::notebook::index_status` as a `200 OK`
    /// JSON response, or a notebook error response when that call fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let reply = match crate::notebook::index_status(&ctx.config) {
            Ok(status) => notebook_json_response(StatusCode::OK, &status),
            Err(reason) => notebook_error_response(&reason),
        };
        Outcome::Next(reply)
    }
}
/// Transition that (re)builds the notebook index.
#[derive(Clone, Copy)]
struct BuildNotebookIndex;

#[async_trait]
impl Transition<(), NotebookWireResponse> for BuildNotebookIndex {
    type Error = String;
    type Resources = ();

    /// Requires a session and runs `crate::notebook::index_notes` on a blocking
    /// worker with a clone of the app config.
    ///
    /// The indexing result is returned as a `200 OK` JSON response; an indexing
    /// failure or a worker join failure is returned as a notebook error response.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let config = ctx.config.clone();
        let worker = tokio::task::spawn_blocking(move || crate::notebook::index_notes(&config));
        let reply = match worker.await {
            Err(join_error) => notebook_error_response(&format!(
                "notebook index worker join failed: {join_error}"
            )),
            Ok(Ok(response)) => notebook_json_response(StatusCode::OK, &response),
            Ok(Err(reason)) => notebook_error_response(&reason),
        };
        Outcome::Next(reply)
    }
}
/// Transition that reports the notebook chunk status.
#[derive(Clone, Copy)]
struct GetNotebookChunks;

#[async_trait]
impl Transition<(), NotebookWireResponse> for GetNotebookChunks {
    type Error = String;
    type Resources = ();

    /// Requires a session and returns `crate::notebook::chunk_status` as a `200 OK`
    /// JSON response, or a notebook error response when that call fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let reply = match crate::notebook::chunk_status(&ctx.config) {
            Ok(status) => notebook_json_response(StatusCode::OK, &status),
            Err(reason) => notebook_error_response(&reason),
        };
        Outcome::Next(reply)
    }
}
/// Transition that (re)builds notebook chunks.
#[derive(Clone, Copy)]
struct BuildNotebookChunks;

#[async_trait]
impl Transition<(), NotebookWireResponse> for BuildNotebookChunks {
    type Error = String;
    type Resources = ();

    /// Requires a session and runs `crate::notebook::chunk_notes` on a blocking
    /// worker with a clone of the app config.
    ///
    /// The chunking result is returned as a `200 OK` JSON response; a chunking
    /// failure or a worker join failure is returned as a notebook error response.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(cause) => return Outcome::Fault(format!("missing app context: {cause}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(denied) => return Outcome::Fault(denied),
        }
        let config = ctx.config.clone();
        let worker = tokio::task::spawn_blocking(move || crate::notebook::chunk_notes(&config));
        let reply = match worker.await {
            Err(join_error) => notebook_error_response(&format!(
                "notebook chunk worker join failed: {join_error}"
            )),
            Ok(Ok(response)) => notebook_json_response(StatusCode::OK, &response),
            Ok(Err(reason)) => notebook_error_response(&reason),
        };
        Outcome::Next(reply)
    }
}
#[derive(Clone, Copy)]
struct RetrieveNotebookChunks;
#[async_trait]
impl Transition<(), NotebookWireResponse> for RetrieveNotebookChunks {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
if let Err(error) = require_session(bus, &app).await {
return Outcome::Fault(error);
}
let query = bus.query_param::<String>("query").unwrap_or_default();
let selection = match app.storage.load_provider_selection().await {
Ok(selection) => selection,
Err(error) => {
return Outcome::Fault(format!("failed to load provider selection: {error}"));
}
};
match retrieve_notebook_context(
&app,
&query,
selection.selected_provider.as_deref(),
selection.selected_model_id.as_deref(),
)
.await
{
Ok(response) => Outcome::Next(notebook_json_response(StatusCode::OK, &response)),
Err(error) => Outcome::Next(notebook_error_response(&error)),
}
}
}
/// Searches the corpus using filters taken from the request's query parameters.
///
/// Parameters read: `q` (falling back to `query`), `source_type`, `format`,
/// `field`, `sort`, `limit`, `offset` and `cursor`. A parameter that fails to
/// parse, or a failed search, is returned as a notebook error response. A
/// missing `AppContext` or a failed `require_session` check is a fault.
#[derive(Clone, Copy)]
struct SearchCorpus;

#[async_trait]
impl Transition<(), NotebookWireResponse> for SearchCorpus {
    type Error = String;
    type Resources = ();

    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<NotebookWireResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(reason) => return Outcome::Fault(reason),
        }

        // `q` takes precedence; `query` is the fallback spelling.
        let query = match bus.query_param::<String>("q") {
            Some(text) => text,
            None => bus.query_param::<String>("query").unwrap_or_default(),
        };

        let raw_source_type = bus.query_param::<String>("source_type");
        let source_type_filter = match parse_search_source_type_filter(raw_source_type.as_deref())
        {
            Ok(filter) => filter,
            Err(error) => return Outcome::Next(notebook_error_response(&error)),
        };

        let format_filter = bus.query_param::<String>("format");

        let raw_field = bus.query_param::<String>("field");
        let field_scope = match parse_search_field_scope_filter(raw_field.as_deref()) {
            Ok(scope) => scope,
            Err(error) => return Outcome::Next(notebook_error_response(&error)),
        };

        let raw_sort = bus.query_param::<String>("sort");
        let sort = match parse_search_sort_filter(raw_sort.as_deref()) {
            Ok(order) => order,
            Err(error) => return Outcome::Next(notebook_error_response(&error)),
        };

        let raw_limit = bus.query_param::<String>("limit");
        let limit = match parse_search_usize_param(raw_limit.as_deref(), "limit") {
            Ok(value) => value,
            Err(error) => return Outcome::Next(notebook_error_response(&error)),
        };

        let raw_offset = bus.query_param::<String>("offset");
        let offset = match parse_search_usize_param(raw_offset.as_deref(), "offset") {
            Ok(value) => value,
            Err(error) => return Outcome::Next(notebook_error_response(&error)),
        };

        let cursor = bus.query_param::<String>("cursor");

        let options = SearchIndexQueryOptions {
            source_type_filter,
            format_filter,
            field_scope,
            sort,
            limit,
            offset,
            cursor,
        };

        let searched = search_corpus(&ctx, &query, options).await;
        Outcome::Next(match searched {
            Ok(results) => notebook_json_response(StatusCode::OK, &results),
            Err(error) => notebook_error_response(&error),
        })
    }
}
/// Returns the value of `crate::search_index::search_index_status` for the
/// application's configuration.
///
/// A missing `AppContext` or a failed `require_session` check is a fault.
#[derive(Clone, Copy)]
struct SearchIndexStatus;

#[async_trait]
impl Transition<(), SearchIndexStatusResponse> for SearchIndexStatus {
    type Error = String;
    type Resources = ();

    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<SearchIndexStatusResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(reason) => return Outcome::Fault(reason),
        }
        let status = crate::search_index::search_index_status(&ctx.config);
        Outcome::Next(status)
    }
}
#[derive(Clone, Copy)]
struct WorkspaceFiles;
#[async_trait]
impl Transition<(), NotebookWireResponse> for WorkspaceFiles {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
if let Err(error) = require_session(bus, &app).await {
return Outcome::Fault(error);
}
let config = app.config.clone();
let raw_path = match decoded_query_param(bus, "path") {
Ok(path) => path,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
match tokio::task::spawn_blocking(move || {
crate::workspace::list_workspace_files(&config, raw_path.as_deref())
})
.await
{
Ok(Ok(response)) => Outcome::Next(notebook_json_response(StatusCode::OK, &response)),
Ok(Err(error)) => Outcome::Next(workspace_error_response(&error)),
Err(error) => {
let error = WorkspaceError::upstream(format!(
"workspace file listing worker join failed: {error}"
));
Outcome::Next(workspace_error_response(&error))
}
}
}
}
#[derive(Clone, Copy)]
struct WorkspaceFilePreview;
#[async_trait]
impl Transition<(), NotebookWireResponse> for WorkspaceFilePreview {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
if let Err(error) = require_session(bus, &app).await {
return Outcome::Fault(error);
}
let config = app.config.clone();
let raw_path = match decoded_query_param(bus, "path") {
Ok(path) => path,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
match tokio::task::spawn_blocking(move || {
crate::workspace::preview_workspace_file(&config, raw_path.as_deref())
})
.await
{
Ok(Ok(response)) => Outcome::Next(notebook_json_response(StatusCode::OK, &response)),
Ok(Err(error)) => Outcome::Next(workspace_error_response(&error)),
Err(error) => {
let error = WorkspaceError::upstream(format!(
"workspace file preview worker join failed: {error}"
));
Outcome::Next(workspace_error_response(&error))
}
}
}
}
#[derive(Clone, Copy)]
struct PreviewWorkspaceFileChange;
#[async_trait]
impl Transition<WorkspaceFileChangePreviewRequest, NotebookWireResponse>
for PreviewWorkspaceFileChange
{
type Error = String;
type Resources = ();
async fn run(
&self,
input: WorkspaceFileChangePreviewRequest,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
let (preview_token, expires_at_ms) = app.new_workspace_file_change_preview_token();
let config = app.config.clone();
let preview_input = input.clone();
match tokio::task::spawn_blocking(move || {
crate::workspace::preview_workspace_file_change(
&config,
&preview_input,
preview_token,
expires_at_ms,
)
})
.await
{
Ok(Ok((response, base_modified_at_ms, base_size_bytes))) => {
app.register_workspace_file_change_preview(
&session_id,
preview_token,
expires_at_ms,
input,
base_modified_at_ms,
base_size_bytes,
);
Outcome::Next(notebook_json_response(StatusCode::OK, &response))
}
Ok(Err(error)) => Outcome::Next(workspace_error_response(&error)),
Err(error) => {
let error = WorkspaceError::upstream(format!(
"workspace file change preview worker join failed: {error}"
));
Outcome::Next(workspace_error_response(&error))
}
}
}
}
/// Applies a previously previewed workspace file change and records an audit entry.
///
/// Steps, in order:
/// 1. Resolve the `AppContext` (fault if missing) and the session (notebook
///    error response on failure).
/// 2. Parse `input.preview_token` as a UUID and call
///    `consume_workspace_file_change_preview` for this session.
/// 3. Start an audit record via storage (fault on failure).
/// 4. Run `apply_workspace_file_change` on a blocking worker with the request
///    and base metadata stored in the preview record.
/// 5. Finish the audit as `Complete` or `Failed`; an audit finalization
///    failure is only logged and does not change the response.
#[derive(Clone, Copy)]
struct ApplyWorkspaceFileChange;
#[async_trait]
impl Transition<WorkspaceFileChangeApplyRequest, NotebookWireResponse>
for ApplyWorkspaceFileChange
{
type Error = String;
type Resources = ();
async fn run(
&self,
input: WorkspaceFileChangeApplyRequest,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
// A token that is not a valid UUID is rejected as an invalid request.
let preview_token = match Uuid::parse_str(&input.preview_token) {
Ok(preview_token) => preview_token,
Err(_) => {
let error = WorkspaceError::invalid_request(
"workspace file change preview token is invalid",
);
return Outcome::Next(workspace_error_response(&error));
}
};
// NOTE(review): the preview is consumed before the audit is started, so a
// failure to start the audit presumably leaves the token unusable — confirm
// against `consume_workspace_file_change_preview`.
let record = match app.consume_workspace_file_change_preview(&session_id, preview_token) {
Ok(record) => record,
Err(error) => return Outcome::Next(workspace_error_response(&error)),
};
let audit = match app
.storage
.start_workspace_file_change_audit(&session_id, &record.request, record.base_size_bytes)
.await
{
Ok(audit) => audit,
Err(error) => {
return Outcome::Fault(format!(
"failed to start workspace file change audit: {error}"
));
}
};
// The blocking worker takes ownership of the request and the base metadata.
let request = record.request;
let base_modified_at_ms = record.base_modified_at_ms;
let base_size_bytes = record.base_size_bytes;
let config = app.config.clone();
let apply_result = tokio::task::spawn_blocking(move || {
crate::workspace::apply_workspace_file_change(
&config,
&request,
base_modified_at_ms,
base_size_bytes,
)
})
.await;
match apply_result {
// Applied: close the audit as complete with the post-change size.
Ok(Ok(response)) => {
if let Err(error) = app
.storage
.finish_workspace_file_change_audit(
&session_id,
audit.audit_id,
WorkspaceFileChangeAuditStatus::Complete,
None,
None,
response.size_bytes_after,
)
.await
{
warn!("workspace file change applied but audit finalization failed: {error}");
}
Outcome::Next(notebook_json_response(StatusCode::OK, &response))
}
// Apply rejected or failed: close the audit as failed with the error
// message and API code, passing the base size as the final size.
Ok(Err(error)) => {
let error_message = error.to_string();
if let Err(audit_error) = app
.storage
.finish_workspace_file_change_audit(
&session_id,
audit.audit_id,
WorkspaceFileChangeAuditStatus::Failed,
Some(&error_message),
Some(error.api_code()),
base_size_bytes,
)
.await
{
warn!("workspace file change failure audit finalization failed: {audit_error}");
}
Outcome::Next(workspace_error_response(&error))
}
// Worker join failure: wrapped as an upstream workspace error and audited
// the same way as an apply failure.
Err(error) => {
let error = WorkspaceError::upstream(format!(
"workspace file change apply worker join failed: {error}"
));
let message = error.to_string();
if let Err(audit_error) = app
.storage
.finish_workspace_file_change_audit(
&session_id,
audit.audit_id,
WorkspaceFileChangeAuditStatus::Failed,
Some(&message),
Some(error.api_code()),
base_size_bytes,
)
.await
{
warn!("workspace file change join audit finalization failed: {audit_error}");
}
Outcome::Next(workspace_error_response(&error))
}
}
}
}
/// Lists workspace file change audit entries for the current session.
///
/// Passes a limit of 50 and the optional `error_code` query parameter to
/// storage. Every failure on this path (context, session, parameter, storage)
/// is a fault.
#[derive(Clone, Copy)]
struct ListWorkspaceFileChangeAudits;

#[async_trait]
impl Transition<(), Vec<WorkspaceFileChangeAuditEntry>> for ListWorkspaceFileChangeAudits {
    type Error = String;
    type Resources = ();

    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<WorkspaceFileChangeAuditEntry>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
            Ok(ctx) => ctx,
        };
        let session_id = match require_session(bus, &ctx).await {
            Err(reason) => return Outcome::Fault(reason),
            Ok(id) => id,
        };
        let error_code = match optional_non_empty_query_param(bus, "error_code") {
            Err(reason) => return Outcome::Fault(reason),
            Ok(code) => code,
        };

        let listed = ctx
            .storage
            .list_workspace_file_change_audits(&session_id, 50, error_code.as_deref())
            .await;
        match listed {
            Err(error) => Outcome::Fault(format!("failed to list workspace file audits: {error}")),
            Ok(entries) => Outcome::Next(entries),
        }
    }
}
#[derive(Clone, Copy)]
struct CreateWorkspaceSourceRoot;
#[async_trait]
impl Transition<WorkspaceSourceRootCreateRequest, NotebookWireResponse>
for CreateWorkspaceSourceRoot
{
type Error = String;
type Resources = ();
async fn run(
&self,
input: WorkspaceSourceRootCreateRequest,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let path = match crate::workspace::resolve_workspace_directory(&app.config, &input.path) {
Ok(path) => path,
Err(error) => return Outcome::Next(workspace_error_response(&error)),
};
let summary = SourceRootSummary {
id: Uuid::new_v4(),
path: path.to_string_lossy().to_string(),
read_only: true,
};
let persisted = match app.storage.upsert_source_root(&summary).await {
Ok(summary) => summary,
Err(error) => return Outcome::Fault(format!("failed to persist source root: {error}")),
};
Outcome::Next(notebook_json_response(
StatusCode::OK,
&app.register_source_root(persisted),
))
}
}
#[derive(Clone, Copy)]
struct RunWorkspaceTask;
#[async_trait]
impl Transition<WorkspaceTaskRunRequest, NotebookWireResponse> for RunWorkspaceTask {
type Error = String;
type Resources = ();
async fn run(
&self,
input: WorkspaceTaskRunRequest,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let config = app.config.clone();
match tokio::task::spawn_blocking(move || {
crate::workspace_tasks::run_workspace_task(&config, input)
})
.await
{
Ok(Ok(response)) => Outcome::Next(notebook_json_response(StatusCode::OK, &response)),
Ok(Err(error)) => Outcome::Next(workspace_task_error_response(&error)),
Err(error) => {
let error = WorkspaceTaskError::upstream(format!(
"workspace task worker join failed: {error}"
));
Outcome::Next(workspace_task_error_response(&error))
}
}
}
}
#[derive(Clone, Copy)]
struct StartWorkspaceTaskRun;
#[async_trait]
impl Transition<WorkspaceTaskRunRequest, NotebookWireResponse> for StartWorkspaceTaskRun {
type Error = String;
type Resources = ();
async fn run(
&self,
input: WorkspaceTaskRunRequest,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
let path = match crate::workspace_tasks::workspace_task_response_path(&app.config, &input) {
Ok(path) => path,
Err(error) => return Outcome::Next(workspace_task_error_response(&error)),
};
let (summary, cancel_requested) =
match app.create_workspace_task_run(&session_id, &input, path) {
Ok(created) => created,
Err(error) => return Outcome::Next(workspace_task_error_response(&error)),
};
if let Err(error) = app
.storage
.upsert_workspace_task_run(&session_id, &summary)
.await
{
app.remove_workspace_task_run(&session_id, summary.run_id);
return Outcome::Next(notebook_error_response(&format!(
"failed to persist workspace task run: {error}"
)));
}
let run_id = summary.run_id;
let task_app = app.clone();
let task_session_id = session_id.clone();
let task_config = task_app.config.clone();
tokio::spawn(async move {
let result = tokio::task::spawn_blocking(move || {
crate::workspace_tasks::run_workspace_task_with_cancel(
&task_config,
input,
cancel_requested,
)
})
.await
.unwrap_or_else(|error| {
Err(WorkspaceTaskError::upstream(format!(
"workspace task worker join failed: {error}"
)))
});
if let Some(summary) =
task_app.finish_workspace_task_run(&task_session_id, run_id, result)
{
let _ = task_app
.storage
.upsert_workspace_task_run(&task_session_id, &summary)
.await;
}
});
Outcome::Next(notebook_json_response(StatusCode::OK, &summary))
}
}
/// Lists workspace task runs for the current session.
///
/// Loads runs from storage (limit 50, optional `error_code` filter) and
/// combines them with the runs held by the `AppContext` through
/// `merge_workspace_task_run_summaries`. Every failure on this path is a fault.
#[derive(Clone, Copy)]
struct ListWorkspaceTaskRuns;

#[async_trait]
impl Transition<(), Vec<WorkspaceTaskRunSummary>> for ListWorkspaceTaskRuns {
    type Error = String;
    type Resources = ();

    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<Vec<WorkspaceTaskRunSummary>, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
            Ok(ctx) => ctx,
        };
        let session_id = match require_session(bus, &ctx).await {
            Err(reason) => return Outcome::Fault(reason),
            Ok(id) => id,
        };
        let error_code = match optional_non_empty_query_param(bus, "error_code") {
            Err(reason) => return Outcome::Fault(reason),
            Ok(code) => code,
        };

        let stored = ctx
            .storage
            .list_workspace_task_runs(&session_id, 50, error_code.as_deref())
            .await;
        let persisted = match stored {
            Err(error) => {
                return Outcome::Fault(format!("failed to load workspace task runs: {error}"));
            }
            Ok(rows) => rows,
        };

        // The in-memory runs are read only after storage has answered.
        let live = ctx.list_workspace_task_runs(&session_id, error_code.as_deref());
        Outcome::Next(merge_workspace_task_run_summaries(live, persisted))
    }
}
#[derive(Clone, Copy)]
struct GetWorkspaceTaskRun;
#[async_trait]
impl Transition<(), NotebookWireResponse> for GetWorkspaceTaskRun {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
let run_id = match workspace_task_run_id_from_bus(bus) {
Ok(run_id) => run_id,
Err(error) => return Outcome::Next(workspace_task_error_response(&error)),
};
if let Some(summary) = app.get_workspace_task_run(&session_id, run_id) {
return Outcome::Next(notebook_json_response(StatusCode::OK, &summary));
}
match app
.storage
.get_workspace_task_run(&session_id, run_id)
.await
{
Ok(Some(summary)) => Outcome::Next(notebook_json_response(StatusCode::OK, &summary)),
Ok(None) => Outcome::Next(workspace_task_error_response(
&WorkspaceTaskError::run_not_found("workspace task run not found"),
)),
Err(error) => Outcome::Next(notebook_error_response(&format!(
"failed to load workspace task run: {error}"
))),
}
}
}
#[derive(Clone, Copy)]
struct CancelWorkspaceTaskRun;
#[async_trait]
impl Transition<(), NotebookWireResponse> for CancelWorkspaceTaskRun {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
let run_id = match workspace_task_run_id_from_bus(bus) {
Ok(run_id) => run_id,
Err(error) => return Outcome::Next(workspace_task_error_response(&error)),
};
if let Some(summary) = app.cancel_workspace_task_run(&session_id, run_id) {
if let Err(error) = app
.storage
.upsert_workspace_task_run(&session_id, &summary)
.await
{
return Outcome::Next(notebook_error_response(&format!(
"failed to persist workspace task cancellation: {error}"
)));
}
return Outcome::Next(notebook_json_response(StatusCode::OK, &summary));
}
match app
.storage
.get_workspace_task_run(&session_id, run_id)
.await
{
Ok(Some(summary)) => Outcome::Next(notebook_json_response(StatusCode::OK, &summary)),
Ok(None) => Outcome::Next(workspace_task_error_response(
&WorkspaceTaskError::run_not_found("workspace task run not found"),
)),
Err(error) => Outcome::Next(notebook_error_response(&format!(
"failed to load workspace task run: {error}"
))),
}
}
}
/// Transition that delegates entirely to `run_search_index_rebuild_transition`.
#[derive(Clone, Copy)]
struct RebuildSearchIndex;
#[async_trait]
impl Transition<(), NotebookWireResponse> for RebuildSearchIndex {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
// All work, including any context or session handling, lives in the helper.
run_search_index_rebuild_transition(bus).await
}
}
/// Transition that delegates entirely to `run_search_index_sync_transition`.
#[derive(Clone, Copy)]
struct SyncSearchIndex;
#[async_trait]
impl Transition<(), NotebookWireResponse> for SyncSearchIndex {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
// All work, including any context or session handling, lives in the helper.
run_search_index_sync_transition(bus).await
}
}
/// Transition that delegates entirely to `run_search_index_rebuild_transition`.
///
/// NOTE(review): this calls the same helper as `RebuildSearchIndex`, so
/// recovery is presumably implemented as a full rebuild — confirm this is
/// intended and not a copy/paste of the rebuild handler.
#[derive(Clone, Copy)]
struct RecoverSearchIndex;
#[async_trait]
impl Transition<(), NotebookWireResponse> for RecoverSearchIndex {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
run_search_index_rebuild_transition(bus).await
}
}
/// Resolves an "open" action for a search result path.
///
/// Reads the `action` query parameter (defaulting to `"copy_path"`) and the
/// decoded `path` query parameter (defaulting to an empty string), then calls
/// `resolve_search_open_action`. Every failure on this path is a fault.
#[derive(Clone, Copy)]
struct ResolveSearchOpenAction;

#[async_trait]
impl Transition<(), SearchOpenActionResponse> for ResolveSearchOpenAction {
    type Error = String;
    type Resources = ();

    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<SearchOpenActionResponse, Self::Error> {
        let ctx = match bus.get_cloned::<AppContext>() {
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
            Ok(ctx) => ctx,
        };
        match require_session(bus, &ctx).await {
            Ok(_) => {}
            Err(reason) => return Outcome::Fault(reason),
        }

        let action = match bus.query_param::<String>("action") {
            Some(requested) => requested,
            None => "copy_path".to_string(),
        };
        let path = match decoded_query_param(bus, "path") {
            Err(reason) => return Outcome::Fault(reason),
            Ok(decoded) => decoded.unwrap_or_default(),
        };

        let resolved = resolve_search_open_action(&ctx, &path, &action).await;
        match resolved {
            Err(reason) => Outcome::Fault(reason),
            Ok(response) => Outcome::Next(response),
        }
    }
}
#[derive(Clone, Copy)]
struct GetNotebookEmbeddings;
#[async_trait]
impl Transition<(), NotebookWireResponse> for GetNotebookEmbeddings {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
if let Err(error) = require_session(bus, &app).await {
return Outcome::Fault(error);
}
let selection = match app.storage.load_provider_selection().await {
Ok(selection) => selection,
Err(error) => {
return Outcome::Fault(format!("failed to load provider selection: {error}"));
}
};
let (provider, model_id) = match (
selection.selected_provider.as_deref(),
selection.selected_model_id.as_deref(),
) {
(Some(provider), Some(model_id)) => (provider, model_id),
_ => {
return Outcome::Next(notebook_error_response(
"provider와 model을 먼저 선택해야 embedding 상태를 확인할 수 있습니다.",
));
}
};
let indexed_source_files = match app.storage.list_indexed_source_files().await {
Ok(files) => files,
Err(error) => {
return Outcome::Fault(format!(
"failed to load indexed source files for embedding status: {error}"
));
}
};
match load_embedding_status(&app.config, &indexed_source_files, provider, model_id).await {
Ok(response) => Outcome::Next(notebook_json_response(StatusCode::OK, &response)),
Err(error) => Outcome::Next(notebook_error_response(&error)),
}
}
}
#[derive(Clone, Copy)]
struct BuildNotebookEmbeddings;
#[async_trait]
impl Transition<(), NotebookWireResponse> for BuildNotebookEmbeddings {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<NotebookWireResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
if let Err(error) = require_session(bus, &app).await {
return Outcome::Fault(error);
}
let selection = match app.storage.load_provider_selection().await {
Ok(selection) => selection,
Err(error) => {
return Outcome::Fault(format!("failed to load provider selection: {error}"));
}
};
let (provider, model_id) = match (
selection.selected_provider.as_deref(),
selection.selected_model_id.as_deref(),
) {
(Some(provider), Some(model_id)) => (provider.to_string(), model_id.to_string()),
_ => {
return Outcome::Next(notebook_error_response(
"provider와 model을 먼저 선택해야 embedding 생성을 시작할 수 있습니다.",
));
}
};
let indexed_source_files = match app.storage.list_indexed_source_files().await {
Ok(files) => files,
Err(error) => {
return Outcome::Fault(format!(
"failed to load indexed source files for embedding build: {error}"
));
}
};
let config = app.config.clone();
let embedding_result =
match build_notebook_embeddings(&config, &indexed_source_files, &provider, &model_id)
.await
{
Ok(response) => response,
Err(error) => return Outcome::Next(notebook_error_response(&error)),
};
Outcome::Next(notebook_json_response(StatusCode::OK, &embedding_result))
}
}
/// Normalizes a chat request by trimming surrounding whitespace from its message.
///
/// Faults with `"message must not be empty"` when nothing remains after
/// trimming; otherwise forwards the request with the trimmed message and the
/// original `conversation_id`.
#[derive(Clone, Copy)]
struct ValidateChatInput;

#[async_trait]
impl Transition<ChatSendRequest, ChatSendRequest> for ValidateChatInput {
    type Error = String;
    type Resources = ();

    async fn run(
        &self,
        input: ChatSendRequest,
        _resources: &Self::Resources,
        _bus: &mut Bus,
    ) -> Outcome<ChatSendRequest, Self::Error> {
        let ChatSendRequest {
            conversation_id,
            message,
        } = input;
        let message = message.trim();
        if message.is_empty() {
            return Outcome::Fault("message must not be empty".to_string());
        }
        Outcome::Next(ChatSendRequest {
            conversation_id,
            message: message.to_string(),
        })
    }
}
#[derive(Clone, Copy)]
struct BuildChatPreview;
#[async_trait]
impl Transition<ChatSendRequest, ChatSendResponse> for BuildChatPreview {
type Error = String;
type Resources = ();
async fn run(
&self,
input: ChatSendRequest,
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<ChatSendResponse, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let selection = match app.storage.load_provider_selection().await {
Ok(selection) => selection,
Err(error) => {
return Outcome::Fault(format!("failed to load provider selection: {error}"));
}
};
let selected_provider = selection.selected_provider.clone();
let selected_model_id = selection.selected_model_id.clone();
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Fault(error),
};
let retrieval = match collect_chat_retrieval(
&app,
&input.message,
selected_provider.as_deref(),
selected_model_id.as_deref(),
)
.await
{
Ok(retrieval) => retrieval,
Err(error) => ChatRetrievalContext {
provider_message: input.message.clone(),
strategy: "none".to_string(),
result_count: 0,
sources: Vec::new(),
warning: Some(error),
},
};
let conversation_id = input
.conversation_id
.clone()
.unwrap_or_else(|| Uuid::new_v4().to_string());
if let Err(error) = app
.storage
.ensure_conversation(&session_id, &conversation_id, Some(&input.message))
.await
{
return Outcome::Fault(format!("failed to ensure conversation: {error}"));
}
let user_message = match app
.storage
.create_message(&NewConversationMessage {
conversation_id: conversation_id.clone(),
role: "user".to_string(),
content: input.message.clone(),
status: "complete".to_string(),
provider: None,
model_id: None,
})
.await
{
Ok(message) => message,
Err(error) => {
return Outcome::Fault(format!("failed to persist user message: {error}"));
}
};
let assistant_message = match app
.storage
.create_message(&NewConversationMessage {
conversation_id: conversation_id.clone(),
role: "assistant".to_string(),
content: String::new(),
status: "streaming".to_string(),
provider: selected_provider.clone(),
model_id: selected_model_id.clone(),
})
.await
{
Ok(message) => message,
Err(error) => {
return Outcome::Fault(format!("failed to persist assistant placeholder: {error}"));
}
};
let conversation = match app
.storage
.find_conversation(&session_id, &conversation_id)
.await
{
Ok(Some(conversation)) => conversation,
Ok(None) => {
return Outcome::Fault(format!(
"conversation missing after persistence: {conversation_id}"
));
}
Err(error) => return Outcome::Fault(format!("failed to reload conversation: {error}")),
};
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatStart {
conversation_id: conversation_id.clone(),
conversation,
user_message: user_message.clone(),
assistant_message: assistant_message.clone(),
selected_provider: selected_provider.clone(),
selected_model_id: selected_model_id.clone(),
},
);
if let Some(warning) = retrieval.warning.as_deref() {
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatWarning {
message: warning.to_string(),
},
);
}
let mut warnings = Vec::new();
if let Some(warning) = retrieval.warning.clone() {
warnings.push(warning);
}
let (accepted, reply_preview, provider_warning) = match (
selected_provider.as_deref(),
selected_model_id.as_deref(),
) {
(Some(provider), Some(model_id)) => {
let assistant_message_id = assistant_message.id.clone();
let assistant_message_id_for_worker = assistant_message.id.clone();
let storage = app.storage.clone();
let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
let delta_persist_task = tokio::spawn(async move {
while let Some(delta) = delta_rx.recv().await {
storage
.append_message_delta(&assistant_message_id_for_worker, &delta)
.await
.map_err(|error| {
format!(
"failed to persist assistant delta for {assistant_message_id_for_worker}: {error}"
)
})?;
}
Ok::<(), String>(())
});
let mut streamed_reply = String::new();
match run_provider_chat(provider, model_id, &retrieval.provider_message, |delta| {
streamed_reply.push_str(delta);
let _ = delta_tx.send(delta.to_string());
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatDelta {
conversation_id: conversation_id.clone(),
message_id: assistant_message_id.clone(),
selected_provider: Some(provider.to_string()),
selected_model_id: Some(model_id.to_string()),
delta: delta.to_string(),
},
);
})
.await
{
Ok(reply) => {
drop(delta_tx);
match delta_persist_task.await {
Ok(Ok(())) => {}
Ok(Err(error)) => return Outcome::Fault(error),
Err(error) => {
return Outcome::Fault(format!(
"assistant delta worker join failed: {error}"
));
}
}
let persisted_assistant = match app
.storage
.update_message_content(&assistant_message.id, &reply, "complete")
.await
{
Ok(message) => message,
Err(error) => {
return Outcome::Fault(format!(
"failed to persist assistant completion: {error}"
));
}
};
let conversation = match app
.storage
.find_conversation(&session_id, &conversation_id)
.await
{
Ok(Some(conversation)) => conversation,
Ok(None) => {
return Outcome::Fault(format!(
"conversation missing after completion: {conversation_id}"
));
}
Err(error) => {
return Outcome::Fault(format!(
"failed to reload conversation after completion: {error}"
));
}
};
let _ = app
.storage
.record_provider_test(provider, true, "chat inference succeeded")
.await;
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatComplete {
conversation_id: conversation_id.clone(),
conversation,
message_id: persisted_assistant.id.clone(),
assistant_message: persisted_assistant,
selected_provider: Some(provider.to_string()),
selected_model_id: Some(model_id.to_string()),
message: reply.clone(),
},
);
(true, reply, None)
}
Err(error) => {
drop(delta_tx);
match delta_persist_task.await {
Ok(Ok(())) => {}
Ok(Err(worker_error)) => return Outcome::Fault(worker_error),
Err(worker_error) => {
return Outcome::Fault(format!(
"assistant delta worker join failed: {worker_error}"
));
}
}
let persisted_error_content =
assistant_error_content(&streamed_reply, &error);
let persisted_assistant = match app
.storage
.update_message_content(
&assistant_message.id,
&persisted_error_content,
"error",
)
.await
{
Ok(message) => message,
Err(storage_error) => {
return Outcome::Fault(format!(
"failed to persist assistant error state: {storage_error}"
));
}
};
let conversation = match app
.storage
.find_conversation(&session_id, &conversation_id)
.await
{
Ok(Some(conversation)) => conversation,
Ok(None) => {
return Outcome::Fault(format!(
"conversation missing after error: {conversation_id}"
));
}
Err(storage_error) => {
return Outcome::Fault(format!(
"failed to reload conversation after error: {storage_error}"
));
}
};
let _ = app
.storage
.record_provider_test(provider, false, &error)
.await;
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatError {
conversation_id: conversation_id.clone(),
conversation,
message_id: persisted_assistant.id.clone(),
assistant_message: persisted_assistant,
selected_provider: Some(provider.to_string()),
selected_model_id: Some(model_id.to_string()),
message: error.clone(),
},
);
(false, error.clone(), Some(error))
}
}
}
(Some(provider), None) => (
{
let warning = "model을 먼저 선택해야 실제 provider 호출을 시작할 수 있습니다."
.to_string();
let persisted_assistant = match app
.storage
.update_message_content(&assistant_message.id, &warning, "error")
.await
{
Ok(message) => message,
Err(error) => {
return Outcome::Fault(format!(
"failed to persist assistant selection error: {error}"
));
}
};
let conversation = match app
.storage
.find_conversation(&session_id, &conversation_id)
.await
{
Ok(Some(conversation)) => conversation,
Ok(None) => {
return Outcome::Fault(format!(
"conversation missing after selection error: {conversation_id}"
));
}
Err(error) => {
return Outcome::Fault(format!(
"failed to reload conversation after selection error: {error}"
));
}
};
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatError {
conversation_id: conversation_id.clone(),
conversation,
message_id: persisted_assistant.id.clone(),
assistant_message: persisted_assistant,
selected_provider: Some(provider.to_string()),
selected_model_id: None,
message: warning.clone(),
},
);
false
},
"model을 먼저 선택해야 실제 provider 호출을 시작할 수 있습니다.".to_string(),
Some("model을 먼저 선택해야 실제 provider 호출을 시작할 수 있습니다.".to_string()),
),
_ => {
let warning = "provider와 model을 먼저 선택해야 chat runtime을 연결할 수 있습니다."
.to_string();
let persisted_assistant = match app
.storage
.update_message_content(&assistant_message.id, &warning, "error")
.await
{
Ok(message) => message,
Err(error) => {
return Outcome::Fault(format!(
"failed to persist assistant runtime selection error: {error}"
));
}
};
let conversation = match app
.storage
.find_conversation(&session_id, &conversation_id)
.await
{
Ok(Some(conversation)) => conversation,
Ok(None) => {
return Outcome::Fault(format!(
"conversation missing after runtime selection error: {conversation_id}"
));
}
Err(error) => {
return Outcome::Fault(format!(
"failed to reload conversation after runtime selection error: {error}"
));
}
};
publish_chat_stream_event(
&app,
&session_id,
&conversation_id,
ChatStreamEvent::ChatError {
conversation_id: conversation_id.clone(),
conversation,
message_id: persisted_assistant.id.clone(),
assistant_message: persisted_assistant,
selected_provider: None,
selected_model_id: None,
message: warning.clone(),
},
);
(
false,
"provider와 model을 먼저 선택해야 chat runtime을 연결할 수 있습니다."
.to_string(),
Some(
"provider와 model을 먼저 선택해야 chat runtime을 연결할 수 있습니다."
.to_string(),
),
)
}
};
if let Some(warning) = provider_warning {
warnings.push(warning);
}
Outcome::Next(ChatSendResponse {
accepted,
conversation_id,
reply_preview,
user_message_id: user_message.id,
assistant_message_id: assistant_message.id,
selected_provider,
selected_model_id,
retrieval_strategy: retrieval.strategy,
retrieval_result_count: retrieval.result_count,
retrieval_sources: retrieval.sources,
warning: combine_warnings(warnings),
})
}
}
#[derive(Clone, Copy)]
struct StreamChat;
#[async_trait]
impl Transition<(), Sse<ChatStream>> for StreamChat {
type Error = String;
type Resources = ();
async fn run(
&self,
_state: (),
_resources: &Self::Resources,
bus: &mut Bus,
) -> Outcome<Sse<ChatStream>, Self::Error> {
let app = match bus.get_cloned::<AppContext>() {
Ok(app) => app,
Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
};
let session_id = match require_session(bus, &app).await {
Ok(session_id) => session_id,
Err(error) => return Outcome::Fault(error),
};
let conversation_filter = bus.query_param::<String>("conversation_id");
let mut receiver = app.chat_events.subscribe();
let mut ticker = interval(Duration::from_secs(15));
let outbound = stream! {
loop {
tokio::select! {
recv_result = receiver.recv() => {
match recv_result {
Ok(event) => {
if !chat_event_matches_subscription(
&event,
&session_id,
conversation_filter.as_deref(),
) {
continue;
}
yield Ok(SseEvent::default().event("chat").data(event.payload));
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
let payload = chat_stream_event_payload(&ChatStreamEvent::ChatWarning {
message: format!("stream lagged and skipped {skipped} events"),
});
yield Ok(SseEvent::default().event("chat").data(payload));
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
_ = ticker.tick() => {
let payload = chat_stream_event_payload(&ChatStreamEvent::Heartbeat {
message: "chat stream ready".to_string(),
});
yield Ok(SseEvent::default().event("heartbeat").data(payload));
}
}
}
};
Outcome::Next(Sse::new(Box::pin(outbound)))
}
}
/// Transition that probes a supported provider's loopback endpoint with a TCP
/// connect and records the outcome in storage.
#[derive(Clone, Copy)]
struct ProviderTest;
#[async_trait]
impl Transition<ProviderTestRequest, JsonWireResponse> for ProviderTest {
    type Error = String;
    type Resources = ();
    /// Runs the connectivity probe for `input.provider`.
    ///
    /// Every failure is reported as a JSON error response through
    /// `Outcome::Next`: a missing app context or a storage failure yields 500,
    /// an unsupported provider or an unparsable endpoint yields 400. On
    /// success the response carries the provider, endpoint, probe result and a
    /// detail string.
    async fn run(
        &self,
        input: ProviderTestRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<JsonWireResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Ok(context) => context,
            Err(error) => {
                let message = format!("missing app context: {error}");
                return Outcome::Next(json_error_wire_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &message,
                ));
            }
        };
        let provider = normalize_provider_id(&input.provider);
        let endpoint = match provider.as_str() {
            "ollama" => "http://127.0.0.1:11434",
            "lmstudio" => "http://127.0.0.1:1234",
            unsupported => {
                let message = format!("unsupported provider: {unsupported}");
                return Outcome::Next(json_error_wire_response(
                    StatusCode::BAD_REQUEST,
                    &message,
                ));
            }
        };
        let address = match endpoint_to_socket_addr(endpoint) {
            Ok(resolved) => resolved,
            Err(error) => {
                return Outcome::Next(json_error_wire_response(StatusCode::BAD_REQUEST, &error));
            }
        };
        // The probe only counts as successful when the connect finishes within
        // two seconds; the stream itself is dropped right away.
        let ok = matches!(
            timeout(Duration::from_secs(2), TcpStream::connect(address)).await,
            Ok(Ok(_))
        );
        let detail = String::from(if ok {
            "loopback connect succeeded"
        } else {
            "loopback connect failed"
        });
        let recorded = app
            .storage
            .record_provider_test(&provider, ok, &detail)
            .await;
        if let Err(error) = recorded {
            let message = format!("failed to persist provider status: {error}");
            return Outcome::Next(json_error_wire_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                &message,
            ));
        }
        let summary = ProviderTestResponse {
            provider,
            endpoint: endpoint.to_string(),
            ok,
            detail,
        };
        Outcome::Next(json_wire_response(StatusCode::OK, &summary))
    }
}
/// Transition that returns the provider selection loaded from storage.
#[derive(Clone, Copy)]
struct GetProviderSelection;
#[async_trait]
impl Transition<(), ProviderSelectionResponse> for GetProviderSelection {
    type Error = String;
    type Resources = ();
    /// Loads the provider selection via `app.storage.load_provider_selection`.
    ///
    /// Faults when the `AppContext` is missing from the bus or when the
    /// storage call fails.
    async fn run(
        &self,
        _state: (),
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<ProviderSelectionResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
            Ok(context) => context,
        };
        let loaded = app.storage.load_provider_selection().await;
        match loaded {
            Err(error) => Outcome::Fault(format!("failed to load provider selection: {error}")),
            Ok(selection) => Outcome::Next(selection),
        }
    }
}
/// Transition that stores (or clears) the selected provider and model.
#[derive(Clone, Copy)]
struct SetProviderSelection;
#[async_trait]
impl Transition<ProviderSelectionUpdateRequest, JsonWireResponse> for SetProviderSelection {
    type Error = String;
    type Resources = ();
    /// Persists the provider selection described by `input`.
    ///
    /// An empty provider (after normalization) clears the selection. Any
    /// provider other than `ollama` or `lmstudio` is rejected with 400. The
    /// model id is trimmed and stored as `None` when blank. All failures are
    /// returned as JSON error responses through `Outcome::Next`.
    async fn run(
        &self,
        input: ProviderSelectionUpdateRequest,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<JsonWireResponse, Self::Error> {
        let app = match bus.get_cloned::<AppContext>() {
            Ok(context) => context,
            Err(error) => {
                let message = format!("missing app context: {error}");
                return Outcome::Next(json_error_wire_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &message,
                ));
            }
        };
        let provider = normalize_provider_id(&input.provider);
        let clearing = provider.is_empty();
        let selection = if clearing {
            ProviderSelectionResponse {
                selected_provider: None,
                selected_model_id: None,
            }
        } else {
            let supported = matches!(provider.as_str(), "ollama" | "lmstudio");
            if !supported {
                // The rejection echoes the provider exactly as the caller sent it.
                let message = format!("unsupported provider: {}", input.provider);
                return Outcome::Next(json_error_wire_response(
                    StatusCode::BAD_REQUEST,
                    &message,
                ));
            }
            let model_id = input
                .model_id
                .map(|model| model.trim().to_string())
                .filter(|model| !model.is_empty());
            ProviderSelectionResponse {
                selected_provider: Some(provider),
                selected_model_id: model_id,
            }
        };
        match app.storage.save_provider_selection(&selection).await {
            Ok(saved) => Outcome::Next(json_wire_response(StatusCode::OK, &saved)),
            Err(error) => {
                let message = if clearing {
                    format!("failed to clear provider selection: {error}")
                } else {
                    format!("failed to persist provider selection: {error}")
                };
                Outcome::Next(json_error_wire_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &message,
                ))
            }
        }
    }
}
/// Builds the `bootstrap-redirect` axon, whose single step is `BootstrapRedirect`.
pub fn bootstrap_redirect_axon() -> Axon<(), RedirectWithCookie, String> {
    let entry = Axon::simple::<String>("bootstrap-redirect");
    entry.then(BootstrapRedirect)
}
/// Builds the `health` axon, whose single step is `Health`.
pub fn health_axon() -> Axon<(), HealthResponse, String> {
    let entry = Axon::simple::<String>("health");
    entry.then(Health)
}
/// Builds the `app-init` axon, whose single step is `Init`.
pub fn init_axon() -> Axon<(), AppInitResponse, String> {
    let entry = Axon::simple::<String>("app-init");
    entry.then(Init)
}
/// Builds the `providers` axon, whose single step is `Providers`.
pub fn providers_axon() -> Axon<(), Vec<ProviderSummary>, String> {
    let entry = Axon::simple::<String>("providers");
    entry.then(Providers)
}
/// Builds the `provider-models` axon, whose single step is `ProviderModels`.
pub fn provider_models_axon() -> Axon<(), serde_json::Value, String> {
    let entry = Axon::simple::<String>("provider-models");
    entry.then(ProviderModels)
}
/// Maps a provider-models error message to a JSON error response.
///
/// Messages containing `authentication required` become 401; messages about a
/// missing provider query parameter or an unsupported provider become 400;
/// everything else is reported as 502.
// NOTE(review): the `&String` parameter is kept as-is; it presumably matches a
// handler signature expected by the router. Confirm before relaxing to `&str`.
pub fn provider_models_error_response(error: &String) -> HttpResponse {
    let auth_failure = error.contains("authentication required");
    let bad_request = error.contains("missing provider query parameter")
        || error.contains("unsupported provider");
    let status = match (auth_failure, bad_request) {
        (true, _) => StatusCode::UNAUTHORIZED,
        (false, true) => StatusCode::BAD_REQUEST,
        (false, false) => StatusCode::BAD_GATEWAY,
    };
    json_error_response(status, error)
}
/// Returns `provider` with surrounding whitespace removed and lowercased.
fn normalize_provider_id(provider: &str) -> String {
    let trimmed = provider.trim();
    trimmed.to_lowercase()
}
/// Reads query parameter `name` from the bus and form-urldecodes its value.
///
/// Returns `Ok(None)` when the parameter is absent. The raw value is re-joined
/// as `name=value` and run through `url::form_urlencoded::parse`; the value of
/// the first parsed pair is returned. An error is returned if the parser
/// yields no pair.
fn decoded_query_param(bus: &Bus, name: &str) -> Result<Option<String>, String> {
    match bus.query_param::<String>(name) {
        None => Ok(None),
        Some(raw) => {
            let pair = format!("{name}={raw}");
            let decoded = url::form_urlencoded::parse(pair.as_bytes())
                .next()
                .map(|(_, value)| value.into_owned());
            match decoded {
                Some(value) => Ok(Some(value)),
                None => Err(format!("invalid query parameter: {name}")),
            }
        }
    }
}
/// Returns the decoded, trimmed value of query parameter `name`, or `None`
/// when the parameter is absent or blank after trimming.
fn optional_non_empty_query_param(bus: &Bus, name: &str) -> Result<Option<String>, String> {
    let decoded = decoded_query_param(bus, name)?;
    let trimmed = decoded
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty());
    Ok(trimmed.map(str::to_string))
}
/// Returns the decoded value of query parameter `name`, or `missing_error`
/// (as an owned string) when the parameter is absent.
fn required_decoded_query_param(
    bus: &Bus,
    name: &str,
    missing_error: &str,
) -> Result<String, String> {
    match decoded_query_param(bus, name)? {
        Some(value) => Ok(value),
        None => Err(missing_error.to_string()),
    }
}
/// Serializes `payload` to JSON and wraps it with `status`.
///
/// If serialization fails, a 500 JSON error response describing the failure is
/// returned instead.
fn json_wire_response<T: Serialize>(status: StatusCode, payload: &T) -> JsonWireResponse {
    let body = match serde_json::to_vec(payload) {
        Ok(bytes) => bytes,
        Err(error) => {
            let message = format!("failed to serialize json response: {error}");
            return json_error_wire_response(StatusCode::INTERNAL_SERVER_ERROR, &message);
        }
    };
    JsonWireResponse {
        status: status.as_u16(),
        body,
    }
}
fn json_error_wire_response(status: StatusCode, error: &str) -> JsonWireResponse {
let body = serde_json::to_vec(&api_error_body(status, status_error_code(status), error))
.unwrap_or_else(|_| format!(r#"{{"error":"{}"}}"#, error).into_bytes());
JsonWireResponse {
status: status.as_u16(),
body,
}
}
/// Assembles an `ApiErrorResponse` from a status, an error code and a message.
fn api_error_body(status: StatusCode, code: &str, error: &str) -> ApiErrorResponse {
    let status = status.as_u16();
    let code = code.to_string();
    let error = error.to_string();
    ApiErrorResponse {
        error,
        code,
        status,
    }
}
/// Serializes `payload` as an `application/json` notebook response with `status`.
///
/// If serialization fails, the failure message is routed through
/// `notebook_error_response` instead.
fn notebook_json_response<T: Serialize>(status: StatusCode, payload: &T) -> NotebookWireResponse {
    serde_json::to_vec(payload)
        .map(|body| NotebookWireResponse {
            status: status.as_u16(),
            content_type: "application/json".to_string(),
            body,
        })
        .unwrap_or_else(|error| {
            notebook_error_response(&format!("failed to serialize notebook response: {error}"))
        })
}
fn notebook_error_response(error: &str) -> NotebookWireResponse {
let (status, code) = classify_notebook_error(error);
notebook_api_error_response(status, code, error)
}
fn notebook_api_error_response(
status: StatusCode,
code: &str,
error: &str,
) -> NotebookWireResponse {
notebook_json_response(status, &api_error_body(status, code, error))
}
fn workspace_error_response(error: &WorkspaceError) -> NotebookWireResponse {
let status =
StatusCode::from_u16(error.status_code()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
notebook_api_error_response(status, error.api_code(), error.message())
}
fn workspace_task_error_response(error: &WorkspaceTaskError) -> NotebookWireResponse {
let status =
StatusCode::from_u16(error.status_code()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
notebook_api_error_response(status, error.api_code(), error.message())
}
/// Derives an HTTP status and API error code from the text of an error message.
///
/// The checks run in a fixed order and the first match wins: authentication,
/// not-found, conflict, upstream failure, invalid request. Messages that match
/// nothing are classified as internal errors.
fn classify_notebook_error(error: &str) -> (StatusCode, &'static str) {
    const NOT_FOUND_MARKERS: &[&str] = &["not found", "missing or stale"];
    const UPSTREAM_MARKERS: &[&str] = &[
        "timed out",
        "failed to run Typst adapter",
        "worker join failed",
    ];
    const INVALID_REQUEST_MARKERS: &[&str] = &[
        "unsupported",
        "required",
        "must",
        "invalid",
        "blocked",
        "선택해야",
        "only ",
        "Typst adapter is ",
    ];
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| error.contains(*marker));

    // "authentication required" also contains "required", so it must be
    // checked before the invalid-request markers.
    if error.contains("authentication required") {
        (StatusCode::UNAUTHORIZED, "auth_required")
    } else if mentions_any(NOT_FOUND_MARKERS) {
        (StatusCode::NOT_FOUND, "not_found")
    } else if error.contains("already exists") {
        (StatusCode::CONFLICT, "conflict")
    } else if mentions_any(UPSTREAM_MARKERS) {
        (StatusCode::BAD_GATEWAY, "upstream_error")
    } else if mentions_any(INVALID_REQUEST_MARKERS) {
        (StatusCode::BAD_REQUEST, "invalid_request")
    } else {
        (StatusCode::INTERNAL_SERVER_ERROR, "internal_error")
    }
}
/// Returns the API error code string for an HTTP status.
///
/// A handful of statuses have dedicated codes; any other 4xx maps to
/// `client_error`, any other 5xx to `server_error`, and everything else to
/// `request_error`.
fn status_error_code(status: StatusCode) -> &'static str {
    const DEDICATED_CODES: [(StatusCode, &str); 7] = [
        (StatusCode::BAD_REQUEST, "invalid_request"),
        (StatusCode::UNAUTHORIZED, "auth_required"),
        (StatusCode::FORBIDDEN, "forbidden"),
        (StatusCode::NOT_FOUND, "not_found"),
        (StatusCode::CONFLICT, "conflict"),
        (StatusCode::BAD_GATEWAY, "upstream_error"),
        (StatusCode::INTERNAL_SERVER_ERROR, "internal_error"),
    ];
    for (candidate, code) in DEDICATED_CODES {
        if candidate == status {
            return code;
        }
    }
    if status.is_client_error() {
        "client_error"
    } else if status.is_server_error() {
        "server_error"
    } else {
        "request_error"
    }
}
/// Builds the `provider-selection-get` axon, whose single step is `GetProviderSelection`.
pub fn provider_selection_get_axon() -> Axon<(), ProviderSelectionResponse, String> {
    let entry = Axon::simple::<String>("provider-selection-get");
    entry.then(GetProviderSelection)
}
/// Builds the `provider-selection-set` axon, whose single step is `SetProviderSelection`.
pub fn provider_selection_set_axon()
-> Axon<ProviderSelectionUpdateRequest, JsonWireResponse, String> {
    let entry = Axon::typed::<ProviderSelectionUpdateRequest, String>("provider-selection-set");
    entry.then(SetProviderSelection)
}
/// Builds the `source-roots-list` axon, whose single step is `ListSourceRoots`.
pub fn source_roots_list_axon() -> Axon<(), Vec<SourceRootSummary>, String> {
    let entry = Axon::simple::<String>("source-roots-list");
    entry.then(ListSourceRoots)
}
/// Builds the `conversations-list` axon, whose single step is `ListConversations`.
pub fn conversations_list_axon() -> Axon<(), Vec<ConversationSummary>, String> {
    let entry = Axon::simple::<String>("conversations-list");
    entry.then(ListConversations)
}
/// Builds the `conversations-create` axon, whose single step is `CreateConversation`.
pub fn conversations_create_axon() -> Axon<ConversationCreateRequest, ConversationSummary, String> {
    let entry = Axon::typed::<ConversationCreateRequest, String>("conversations-create");
    entry.then(CreateConversation)
}
/// Builds the `conversation-messages` axon, whose single step is `ListConversationMessages`.
pub fn conversation_messages_axon() -> Axon<(), Vec<ConversationMessage>, String> {
    let entry = Axon::simple::<String>("conversation-messages");
    entry.then(ListConversationMessages)
}
/// Builds the `conversation-delete` axon, whose single step is `DeleteConversation`.
pub fn conversation_delete_axon() -> Axon<(), ConversationDeleteResponse, String> {
    let entry = Axon::simple::<String>("conversation-delete");
    entry.then(DeleteConversation)
}
/// Builds the `notebook-notes-list` axon, whose single step is `ListNotebookNotes`.
pub fn notebook_notes_list_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-notes-list");
    entry.then(ListNotebookNotes)
}
/// Builds the `notebook-note-read` axon, whose single step is `ReadNotebookNote`.
pub fn notebook_note_read_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-note-read");
    entry.then(ReadNotebookNote)
}
/// Builds the `notebook-artifact-read` axon, whose single step is `ReadNotebookArtifact`.
pub fn notebook_artifact_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-artifact-read");
    entry.then(ReadNotebookArtifact)
}
/// Builds the `notebook-note-create` axon, whose single step is `CreateNotebookNote`.
pub fn notebook_note_create_axon() -> Axon<NotebookNoteCreateRequest, NotebookWireResponse, String>
{
    let entry = Axon::typed::<NotebookNoteCreateRequest, String>("notebook-note-create");
    entry.then(CreateNotebookNote)
}
/// Builds the `notebook-note-write` axon, whose single step is `WriteNotebookNote`.
pub fn notebook_note_write_axon() -> Axon<NotebookNoteWriteRequest, NotebookWireResponse, String> {
    let entry = Axon::typed::<NotebookNoteWriteRequest, String>("notebook-note-write");
    entry.then(WriteNotebookNote)
}
/// Builds the `notebook-notes-search` axon, whose single step is `SearchNotebookNotes`.
pub fn notebook_notes_search_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-notes-search");
    entry.then(SearchNotebookNotes)
}
/// Builds the `notebook-adapters-list` axon, whose single step is `ListNotebookAdapters`.
pub fn notebook_adapters_list_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-adapters-list");
    entry.then(ListNotebookAdapters)
}
/// Builds the `notebook-note-render` axon, whose single step is `RenderNotebookNote`.
pub fn notebook_note_render_axon() -> Axon<NotebookRenderRequest, NotebookWireResponse, String> {
    let entry = Axon::typed::<NotebookRenderRequest, String>("notebook-note-render");
    entry.then(RenderNotebookNote)
}
/// Builds the `notebook-index-get` axon, whose single step is `GetNotebookIndex`.
pub fn notebook_index_get_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-index-get");
    entry.then(GetNotebookIndex)
}
/// Builds the `notebook-index-build` axon, whose single step is `BuildNotebookIndex`.
pub fn notebook_index_build_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-index-build");
    entry.then(BuildNotebookIndex)
}
/// Builds the `notebook-chunks-get` axon, whose single step is `GetNotebookChunks`.
pub fn notebook_chunks_get_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-chunks-get");
    entry.then(GetNotebookChunks)
}
/// Builds the `notebook-chunks-build` axon, whose single step is `BuildNotebookChunks`.
pub fn notebook_chunks_build_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-chunks-build");
    entry.then(BuildNotebookChunks)
}
/// Builds the `notebook-embeddings-get` axon, whose single step is `GetNotebookEmbeddings`.
pub fn notebook_embeddings_get_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-embeddings-get");
    entry.then(GetNotebookEmbeddings)
}
/// Builds the `notebook-embeddings-build` axon, whose single step is `BuildNotebookEmbeddings`.
pub fn notebook_embeddings_build_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-embeddings-build");
    entry.then(BuildNotebookEmbeddings)
}
/// Builds the `notebook-retrieve` axon, whose single step is `RetrieveNotebookChunks`.
pub fn notebook_retrieve_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("notebook-retrieve");
    entry.then(RetrieveNotebookChunks)
}
/// Builds the `search` axon, whose single step is `SearchCorpus`.
pub fn search_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("search");
    entry.then(SearchCorpus)
}
/// Builds the `search-index-status` axon, whose single step is `SearchIndexStatus`.
pub fn search_index_status_axon() -> Axon<(), SearchIndexStatusResponse, String> {
    let entry = Axon::simple::<String>("search-index-status");
    entry.then(SearchIndexStatus)
}
/// Builds the `workspace-files` axon, whose single step is `WorkspaceFiles`.
pub fn workspace_files_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("workspace-files");
    entry.then(WorkspaceFiles)
}
/// Builds the `workspace-file-preview` axon, whose single step is `WorkspaceFilePreview`.
pub fn workspace_file_preview_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("workspace-file-preview");
    entry.then(WorkspaceFilePreview)
}
/// Builds the `workspace-file-change-preview` axon, whose single step is
/// `PreviewWorkspaceFileChange`.
pub fn workspace_file_change_preview_axon()
-> Axon<WorkspaceFileChangePreviewRequest, NotebookWireResponse, String> {
    let entry =
        Axon::typed::<WorkspaceFileChangePreviewRequest, String>("workspace-file-change-preview");
    entry.then(PreviewWorkspaceFileChange)
}
/// Builds the `workspace-file-change-apply` axon, whose single step is
/// `ApplyWorkspaceFileChange`.
pub fn workspace_file_change_apply_axon()
-> Axon<WorkspaceFileChangeApplyRequest, NotebookWireResponse, String> {
    let entry =
        Axon::typed::<WorkspaceFileChangeApplyRequest, String>("workspace-file-change-apply");
    entry.then(ApplyWorkspaceFileChange)
}
/// Builds the `workspace-file-change-audits-list` axon, whose single step is
/// `ListWorkspaceFileChangeAudits`.
pub fn workspace_file_change_audits_list_axon()
-> Axon<(), Vec<WorkspaceFileChangeAuditEntry>, String> {
    let entry = Axon::simple::<String>("workspace-file-change-audits-list");
    entry.then(ListWorkspaceFileChangeAudits)
}
/// Builds the `workspace-source-root-create` axon, whose single step is
/// `CreateWorkspaceSourceRoot`.
pub fn workspace_source_root_create_axon()
-> Axon<WorkspaceSourceRootCreateRequest, NotebookWireResponse, String> {
    let entry =
        Axon::typed::<WorkspaceSourceRootCreateRequest, String>("workspace-source-root-create");
    entry.then(CreateWorkspaceSourceRoot)
}
/// Builds the `workspace-task-run` axon, whose single step is `RunWorkspaceTask`.
pub fn workspace_task_run_axon() -> Axon<WorkspaceTaskRunRequest, NotebookWireResponse, String> {
    let entry = Axon::typed::<WorkspaceTaskRunRequest, String>("workspace-task-run");
    entry.then(RunWorkspaceTask)
}
/// Builds the `workspace-task-run-start` axon, whose single step is `StartWorkspaceTaskRun`.
pub fn workspace_task_run_start_axon() -> Axon<WorkspaceTaskRunRequest, NotebookWireResponse, String>
{
    let entry = Axon::typed::<WorkspaceTaskRunRequest, String>("workspace-task-run-start");
    entry.then(StartWorkspaceTaskRun)
}
/// Builds the `workspace-task-runs-list` axon, whose single step is `ListWorkspaceTaskRuns`.
pub fn workspace_task_runs_list_axon() -> Axon<(), Vec<WorkspaceTaskRunSummary>, String> {
    let entry = Axon::simple::<String>("workspace-task-runs-list");
    entry.then(ListWorkspaceTaskRuns)
}
/// Builds the `workspace-task-run-get` axon, whose single step is `GetWorkspaceTaskRun`.
pub fn workspace_task_run_get_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("workspace-task-run-get");
    entry.then(GetWorkspaceTaskRun)
}
/// Builds the `workspace-task-run-cancel` axon, whose single step is `CancelWorkspaceTaskRun`.
pub fn workspace_task_run_cancel_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("workspace-task-run-cancel");
    entry.then(CancelWorkspaceTaskRun)
}
/// Builds the `search-index-rebuild` axon, whose single step is `RebuildSearchIndex`.
pub fn search_index_rebuild_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("search-index-rebuild");
    entry.then(RebuildSearchIndex)
}
/// Builds the `search-index-sync` axon, whose single step is `SyncSearchIndex`.
pub fn search_index_sync_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("search-index-sync");
    entry.then(SyncSearchIndex)
}
/// Builds the `search-index-recover` axon, whose single step is `RecoverSearchIndex`.
pub fn search_index_recover_axon() -> Axon<(), NotebookWireResponse, String> {
    let entry = Axon::simple::<String>("search-index-recover");
    entry.then(RecoverSearchIndex)
}
/// Builds the `search-open-action` axon, whose single step is `ResolveSearchOpenAction`.
pub fn search_open_action_axon() -> Axon<(), SearchOpenActionResponse, String> {
    let entry = Axon::simple::<String>("search-open-action");
    entry.then(ResolveSearchOpenAction)
}
/// Builds the `source-roots-create` axon, whose single step is `CreateSourceRoot`.
pub fn source_roots_create_axon() -> Axon<SourceRootCreateInput, NotebookWireResponse, String> {
    let entry = Axon::typed::<SourceRootCreateInput, String>("source-roots-create");
    entry.then(CreateSourceRoot)
}
/// Builds the `ingest-jobs` axon, whose single step is `ListIngestJobs`.
pub fn ingest_jobs_axon() -> Axon<(), Vec<IngestJobSummary>, String> {
    let entry = Axon::simple::<String>("ingest-jobs");
    entry.then(ListIngestJobs)
}
/// Builds the `ingest-status` axon, whose single step is `GetIngestStatus`.
pub fn ingest_status_axon() -> Axon<(), IngestStatusResponse, String> {
    let entry = Axon::simple::<String>("ingest-status");
    entry.then(GetIngestStatus)
}
/// Builds the `ingest-rescan` axon, whose single step is `RescanIngest`.
pub fn ingest_rescan_axon() -> Axon<IngestRescanRequest, NotebookWireResponse, String> {
    let entry = Axon::typed::<IngestRescanRequest, String>("ingest-rescan");
    entry.then(RescanIngest)
}
/// Builds the `chat-send` axon: `ValidateChatInput` followed by `BuildChatPreview`.
pub fn chat_send_axon() -> Axon<ChatSendRequest, ChatSendResponse, String> {
    let validated = Axon::typed::<ChatSendRequest, String>("chat-send").then(ValidateChatInput);
    validated.then(BuildChatPreview)
}
/// Builds the `chat-stream` axon, whose single step is `StreamChat`.
pub fn stream_chat_axon() -> Axon<(), Sse<ChatStream>, String> {
    let entry = Axon::simple::<String>("chat-stream");
    entry.then(StreamChat)
}
/// Builds the `provider-test` axon, whose single step is `ProviderTest`.
pub fn provider_test_axon() -> Axon<ProviderTestRequest, JsonWireResponse, String> {
    let entry = Axon::typed::<ProviderTestRequest, String>("provider-test");
    entry.then(ProviderTest)
}
/// Resolves the session id carried by the `soma_studio_session` cookie.
///
/// Returns `None` when the cookie jar or the cookie is missing, when the
/// storage lookup fails, when storage does not know the session, or when
/// `remember_session` returns `None`. A session already known to
/// `app.lookup_session` is accepted without consulting storage.
async fn session_id_from_bus(bus: &mut Bus, app: &AppContext) -> Option<String> {
    let jar = bus.get_cloned::<CookieJar>().ok()?;
    let session_id = jar.get("soma_studio_session")?.to_string();
    let cached = app.lookup_session(&session_id).is_some();
    if cached {
        return Some(session_id);
    }
    let persisted = app.storage.session_exists(&session_id).await.ok()?;
    if !persisted {
        return None;
    }
    // Register the session with the app so the next lookup can skip storage.
    app.remember_session(session_id.clone())?;
    Some(session_id)
}
/// Extracts the conversation id for the current request.
///
/// Prefers the `id` path parameter. When that is unavailable, the stored
/// `RequestPath` is matched against `/api/conversations/{id}` and
/// `/api/conversations/{id}/messages`.
fn conversation_id_from_bus(bus: &mut Bus) -> Result<String, String> {
    const MISSING_ID: &str = "Missing or invalid path parameter: id";

    if let Ok(from_route) = bus.path_param::<String>("id") {
        return Ok(from_route);
    }
    let path = match bus.get_cloned::<RequestPath>() {
        Ok(request_path) => request_path.0,
        Err(_) => return Err(MISSING_ID.to_string()),
    };
    let segments: Vec<&str> = path.trim_matches('/').split('/').collect();
    match segments.as_slice() {
        ["api", "conversations", id] | ["api", "conversations", id, "messages"] => {
            Ok((*id).to_string())
        }
        _ => Err(MISSING_ID.to_string()),
    }
}
/// Extracts and parses the workspace task run id for the current request.
///
/// Prefers the `id` path parameter. When that is unavailable, the stored
/// `RequestPath` is matched against `/api/workspace/task-runs/{id}` and
/// `/api/workspace/task-runs/{id}/cancel`. Any failure, including an id that
/// is not a valid UUID, yields the same invalid-request error.
fn workspace_task_run_id_from_bus(bus: &mut Bus) -> Result<Uuid, WorkspaceTaskError> {
    let invalid_id =
        || WorkspaceTaskError::invalid_request("Missing or invalid path parameter: id");

    if let Ok(from_route) = bus.path_param::<String>("id") {
        return Uuid::parse_str(&from_route).map_err(|_| invalid_id());
    }
    let path = match bus.get_cloned::<RequestPath>() {
        Ok(request_path) => request_path.0,
        Err(_) => return Err(invalid_id()),
    };
    let segments: Vec<&str> = path.trim_matches('/').split('/').collect();
    let raw_id = match segments.as_slice() {
        ["api", "workspace", "task-runs", id]
        | ["api", "workspace", "task-runs", id, "cancel"] => *id,
        _ => return Err(invalid_id()),
    };
    Uuid::parse_str(raw_id).map_err(|_| invalid_id())
}
/// Merges in-memory and persisted task run summaries into one list.
///
/// Summaries are de-duplicated by `run_id`, with `memory` entries taking
/// precedence because they are visited first. The result is ordered by
/// `started_at_ms` descending and capped at 50 entries.
fn merge_workspace_task_run_summaries(
    memory: Vec<WorkspaceTaskRunSummary>,
    persisted: Vec<WorkspaceTaskRunSummary>,
) -> Vec<WorkspaceTaskRunSummary> {
    let mut seen_run_ids = HashSet::new();
    let mut merged: Vec<WorkspaceTaskRunSummary> = memory
        .into_iter()
        .chain(persisted)
        .filter(|summary| seen_run_ids.insert(summary.run_id))
        .collect();
    merged.sort_by_key(|summary| std::cmp::Reverse(summary.started_at_ms));
    merged.truncate(50);
    merged
}
/// Wraps `event` in an envelope for the given session and conversation and
/// publishes it through `app.publish_chat_event`.
fn publish_chat_stream_event(
    app: &AppContext,
    session_id: &str,
    conversation_id: &str,
    event: ChatStreamEvent,
) {
    let envelope = chat_stream_event_envelope(session_id, conversation_id, &event);
    app.publish_chat_event(envelope);
}
/// Chooses the text to use for an assistant message that ended in an error.
///
/// Returns the streamed reply when there is one; the error text is used only
/// when nothing was streamed.
fn assistant_error_content(streamed_reply: &str, error: &str) -> String {
    let content = if streamed_reply.is_empty() {
        error
    } else {
        streamed_reply
    };
    content.to_string()
}
/// Retrieval outcome produced by `collect_chat_retrieval` for one chat message.
#[derive(Debug, Clone)]
struct ChatRetrievalContext {
    /// Message text: the original message when retrieval found nothing,
    /// otherwise the output of `build_chat_message_with_retrieval`.
    provider_message: String,
    /// Strategy label reported by retrieval; `"none"` when there were no results.
    strategy: String,
    /// Total number of retrieval results (not limited to the listed sources).
    result_count: usize,
    /// Sources taken from the first results; `collect_chat_retrieval` keeps at most three.
    sources: Vec<ChatRetrievalSource>,
    /// Optional warning; `collect_chat_retrieval` currently always sets this to `None`.
    warning: Option<String>,
}
async fn collect_chat_retrieval(
app: &AppContext,
message: &str,
provider: Option<&str>,
model_id: Option<&str>,
) -> Result<ChatRetrievalContext, String> {
let message = message.to_string();
let retrieval_result = retrieve_notebook_context(app, &message, provider, model_id).await?;
if retrieval_result.results.is_empty() {
return Ok(ChatRetrievalContext {
provider_message: message,
strategy: "none".to_string(),
result_count: 0,
sources: Vec::new(),
warning: None,
});
}
Ok(ChatRetrievalContext {
provider_message: build_chat_message_with_retrieval(&message, &retrieval_result),
strategy: retrieval_result.strategy.clone(),
result_count: retrieval_result.results.len(),
sources: retrieval_result
.results
.iter()
.take(3)
.map(|result| ChatRetrievalSource {
path: result.path.clone(),
format: result.format.clone(),
chunk_index: result.chunk_index,
score: result.score,
snippet: result.snippet.clone(),
provenance: result.provenance.clone(),
})
.collect(),
warning: None,
})
}
/// Retrieves context for `query`, preferring semantic retrieval and falling
/// back to lexical retrieval.
///
/// When both `provider` and `model_id` are given, the query is embedded via
/// `run_provider_embeddings`. If that yields at least one vector, notebook and
/// source-root semantic retrieval run on blocking workers and their results are
/// merged; a non-empty merge is returned directly. In every other case (no
/// provider/model, embedding failure, no vectors, empty merge) the function
/// falls back to `retrieve_lexical_corpus`.
///
/// # Errors
///
/// Returns an error when the indexed source files cannot be loaded, when a
/// semantic worker fails to join or reports an error, or when the lexical
/// fallback fails.
async fn retrieve_notebook_context(
    app: &AppContext,
    query: &str,
    provider: Option<&str>,
    model_id: Option<&str>,
) -> Result<soma_studio_core::NotebookRetrievalResponse, String> {
    let indexed_source_files = app
        .storage
        .list_indexed_source_files()
        .await
        .map_err(|error| format!("failed to load indexed source files for retrieval: {error}"))?;
    if let (Some(provider), Some(model_id)) = (provider, model_id) {
        match run_provider_embeddings(provider, model_id, &[query.to_string()]).await {
            Ok(vectors) if !vectors.is_empty() => {
                // Each blocking worker needs its own owned copies of the inputs.
                let notebook_config = app.config.clone();
                let notebook_query = query.to_string();
                let notebook_provider = provider.to_string();
                let notebook_model_id = model_id.to_string();
                let notebook_vector = vectors[0].clone();
                let notebook_semantic = tokio::task::spawn_blocking(move || {
                    crate::notebook::retrieve_notes_with_query_vector(
                        &notebook_config,
                        &notebook_query,
                        &notebook_provider,
                        &notebook_model_id,
                        &notebook_vector,
                    )
                });
                let source_root_config = app.config.clone();
                // The file list is not needed after this point, so it is moved
                // into the worker instead of cloned.
                let source_root_files = indexed_source_files;
                let source_root_provider = provider.to_string();
                let source_root_model_id = model_id.to_string();
                let source_root_vector = vectors[0].clone();
                let source_root_semantic = tokio::task::spawn_blocking(move || {
                    crate::ingest::retrieve_source_root_with_query_vector(
                        &source_root_config,
                        &source_root_files,
                        &source_root_provider,
                        &source_root_model_id,
                        &source_root_vector,
                    )
                });
                let notebook_semantic = notebook_semantic.await.map_err(|error| {
                    format!("notebook semantic retrieval worker join failed: {error}")
                })??;
                let source_root_semantic = source_root_semantic.await.map_err(|error| {
                    format!("source-root semantic retrieval worker join failed: {error}")
                })??;
                let merged = merge_retrieval_responses(notebook_semantic, source_root_semantic);
                if !merged.results.is_empty() {
                    return Ok(merged);
                }
            }
            // No vectors or an embedding failure: deliberately fall through to
            // the lexical path rather than failing the request.
            Ok(_) => {}
            Err(_) => {}
        }
    }
    retrieve_lexical_corpus(app, query).await
}
/// Runs `query` against the search index using the app's configuration and the
/// given query options.
async fn search_corpus(
    app: &AppContext,
    query: &str,
    options: SearchIndexQueryOptions,
) -> Result<SearchResponse, String> {
    let config = app.config.clone();
    let owned_query = query.to_string();
    crate::search_index::search_index_query(config, owned_query, options).await
}
/// Runs lexical retrieval for `query` with the `SearchProfile::RagContext` profile.
async fn retrieve_lexical_corpus(
    app: &AppContext,
    query: &str,
) -> Result<soma_studio_core::NotebookRetrievalResponse, String> {
    let profile = SearchProfile::RagContext;
    retrieve_lexical_corpus_with_profile(app, query, profile).await
}
/// Runs lexical retrieval over notebook notes and indexed source-root files and
/// merges the two result sets.
///
/// Both lookups run on blocking workers. `SearchProfile::RagContext` uses the
/// plain retrieval functions, while `SearchProfile::InteractiveSearch` uses the
/// profile-aware variants. The merged response is limited to
/// `profile.default_limit()`.
///
/// # Errors
///
/// Returns an error when the indexed source files cannot be loaded, or when
/// either worker fails to join or reports an error.
async fn retrieve_lexical_corpus_with_profile(
    app: &AppContext,
    query: &str,
    profile: SearchProfile,
) -> Result<soma_studio_core::NotebookRetrievalResponse, String> {
    let indexed_source_files = app
        .storage
        .list_indexed_source_files()
        .await
        .map_err(|error| format!("failed to load indexed source files for retrieval: {error}"))?;
    // Each blocking worker needs its own owned copies of the inputs.
    let notebook_config = app.config.clone();
    let notebook_query = query.to_string();
    let notebook_lexical = tokio::task::spawn_blocking(move || match profile {
        SearchProfile::RagContext => {
            crate::notebook::retrieve_notes(&notebook_config, &notebook_query)
        }
        SearchProfile::InteractiveSearch => {
            crate::notebook::retrieve_notes_with_profile(&notebook_config, &notebook_query, profile)
        }
    });
    let source_root_config = app.config.clone();
    // The file list is only needed by this worker, so it is moved, not cloned.
    let source_root_files = indexed_source_files;
    let source_root_query = query.to_string();
    let source_root_lexical = tokio::task::spawn_blocking(move || match profile {
        SearchProfile::RagContext => crate::ingest::retrieve_source_root_text(
            &source_root_config,
            &source_root_files,
            &source_root_query,
        ),
        SearchProfile::InteractiveSearch => crate::ingest::retrieve_source_root_text_with_profile(
            &source_root_config,
            &source_root_files,
            &source_root_query,
            profile,
        ),
    });
    let notebook_lexical = notebook_lexical
        .await
        .map_err(|error| format!("notebook lexical retrieval worker join failed: {error}"))??;
    let source_root_lexical = source_root_lexical
        .await
        .map_err(|error| format!("source-root retrieval worker join failed: {error}"))??;
    Ok(merge_retrieval_responses_with_limit(
        notebook_lexical,
        source_root_lexical,
        profile.default_limit(),
    ))
}
/// Parses the optional `source_type` search filter.
///
/// A missing or blank value means "no filter". Accepted values are `notebook`,
/// `source_root` and `source-root`; anything else is an error.
fn parse_search_source_type_filter(
    filter: Option<&str>,
) -> Result<Option<SearchSourceType>, String> {
    let Some(requested) = filter.map(str::trim).filter(|value| !value.is_empty()) else {
        return Ok(None);
    };
    let source_type = match requested {
        "notebook" => SearchSourceType::Notebook,
        "source_root" | "source-root" => SearchSourceType::SourceRoot,
        other => return Err(format!("unsupported search source_type: {other}")),
    };
    Ok(Some(source_type))
}
/// Parses the optional search field-scope filter.
///
/// A missing or blank value means "no filter". Accepted values are `all`,
/// `title`, `body` and `path`; anything else is an error.
fn parse_search_field_scope_filter(
    filter: Option<&str>,
) -> Result<Option<SearchFieldScope>, String> {
    let Some(requested) = filter.map(str::trim).filter(|value| !value.is_empty()) else {
        return Ok(None);
    };
    let scope = match requested {
        "all" => SearchFieldScope::All,
        "title" => SearchFieldScope::Title,
        "body" => SearchFieldScope::Body,
        "path" => SearchFieldScope::Path,
        other => return Err(format!("unsupported search field: {other}")),
    };
    Ok(Some(scope))
}
/// Parses the optional search sort order.
///
/// A missing or blank value means "no explicit sort". Accepted values are
/// `relevance`, `updated_at`/`updated-at` and `indexed_at`/`indexed-at`;
/// anything else is an error.
fn parse_search_sort_filter(filter: Option<&str>) -> Result<Option<SearchSort>, String> {
    let Some(requested) = filter.map(str::trim).filter(|value| !value.is_empty()) else {
        return Ok(None);
    };
    let sort = match requested {
        "relevance" => SearchSort::Relevance,
        "updated_at" | "updated-at" => SearchSort::UpdatedAt,
        "indexed_at" | "indexed-at" => SearchSort::IndexedAt,
        other => return Err(format!("unsupported search sort: {other}")),
    };
    Ok(Some(sort))
}
/// Parses an optional non-negative integer search parameter.
///
/// A missing or blank value yields `Ok(None)`. Surrounding whitespace is ignored.
/// `name` is only used to label the error message when the value is not a valid
/// `usize`.
fn parse_search_usize_param(value: Option<&str>, name: &str) -> Result<Option<usize>, String> {
    let trimmed = value.unwrap_or_default().trim();
    if trimmed.is_empty() {
        return Ok(None);
    }
    match trimmed.parse::<usize>() {
        Ok(parsed) => Ok(Some(parsed)),
        Err(_) => Err(format!("unsupported search {name}: {trimmed}")),
    }
}
/// Transition that rebuilds the search index for an authenticated session.
///
/// Faults when the app context is missing from the bus or when no session is
/// present. A rebuild failure is not a fault: it is returned as an error wire
/// response through `Outcome::Next`.
async fn run_search_index_rebuild_transition(
    bus: &mut Bus,
) -> Outcome<NotebookWireResponse, String> {
    let app = match bus.get_cloned::<AppContext>() {
        Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
        Ok(context) => context,
    };
    if let Err(reason) = require_session(bus, &app).await {
        return Outcome::Fault(reason);
    }
    let wire = match rebuild_search_index_for_app(&app).await {
        Ok(summary) => notebook_json_response(StatusCode::OK, &summary),
        Err(message) => notebook_error_response(&message),
    };
    Outcome::Next(wire)
}
/// Transition that synchronizes the search index for an authenticated session.
///
/// Faults when the app context is missing from the bus or when no session is
/// present. A sync failure is not a fault: it is returned as an error wire
/// response through `Outcome::Next`.
async fn run_search_index_sync_transition(bus: &mut Bus) -> Outcome<NotebookWireResponse, String> {
    let app = match bus.get_cloned::<AppContext>() {
        Err(error) => return Outcome::Fault(format!("missing app context: {error}")),
        Ok(context) => context,
    };
    if let Err(reason) = require_session(bus, &app).await {
        return Outcome::Fault(reason);
    }
    let wire = match sync_search_index_for_app(&app).await {
        Ok(summary) => notebook_json_response(StatusCode::OK, &summary),
        Err(message) => notebook_error_response(&message),
    };
    Outcome::Next(wire)
}
/// Loads the indexed source files from storage and rebuilds the search index
/// from them using a clone of the app configuration.
///
/// # Errors
/// Returns a descriptive string when storage cannot list the indexed source
/// files, or whatever error the rebuild itself reports.
async fn rebuild_search_index_for_app(
    app: &AppContext,
) -> Result<soma_studio_core::SearchIndexRebuildResponse, String> {
    let files = match app.storage.list_indexed_source_files().await {
        Ok(files) => files,
        Err(error) => {
            return Err(format!(
                "failed to load indexed source files for search index: {error}"
            ));
        }
    };
    crate::search_index::rebuild_search_index(app.config.clone(), files).await
}
/// Loads the indexed source files from storage and synchronizes the search
/// index against them using a clone of the app configuration.
///
/// # Errors
/// Returns a descriptive string when storage cannot list the indexed source
/// files, or whatever error the sync itself reports.
async fn sync_search_index_for_app(
    app: &AppContext,
) -> Result<soma_studio_core::SearchIndexRebuildResponse, String> {
    let files = match app.storage.list_indexed_source_files().await {
        Ok(files) => files,
        Err(error) => {
            return Err(format!(
                "failed to load indexed source files for search index sync: {error}"
            ));
        }
    };
    crate::search_index::sync_search_index(app.config.clone(), files).await
}
/// Resolves a search result path to a canonical filesystem path for an "open"
/// action.
///
/// Only the `copy_path` action is supported, and only for results that live
/// under a registered source root. The candidate path is canonicalized and must
/// remain inside the canonicalized source root.
///
/// # Errors
/// Returns an error string when the action is unsupported, the result path is
/// malformed, the source root is unknown, canonicalization fails, or the
/// resolved path escapes the source root.
async fn resolve_search_open_action(
    app: &AppContext,
    path: &str,
    action: &str,
) -> Result<SearchOpenActionResponse, String> {
    if action != "copy_path" {
        return Err(format!("unsupported search open action: {action}"));
    }
    let (source_root_id, relative_path) = parse_source_root_result_path(path)?;
    let registered = match app.storage.list_source_roots().await {
        Ok(roots) => roots,
        Err(error) => return Err(format!("failed to load source roots: {error}")),
    };
    let Some(source_root) = registered
        .iter()
        .find(|entry| entry.id.to_string() == source_root_id)
    else {
        return Err(format!("source root is not registered: {source_root_id}"));
    };
    let root_dir = PathBuf::from(&source_root.path);
    let target = root_dir.join(Path::new(&relative_path));
    let canonical_root = match root_dir.canonicalize() {
        Ok(resolved) => resolved,
        Err(error) => return Err(format!("failed to canonicalize source root: {error}")),
    };
    let canonical_candidate = match target.canonicalize() {
        Ok(resolved) => resolved,
        Err(error) => return Err(format!("failed to canonicalize source path: {error}")),
    };
    // Canonicalization resolves symlinks, so this check also rejects links that
    // point outside the registered root.
    if !canonical_candidate.starts_with(&canonical_root) {
        return Err("resolved source path escapes the registered source root".to_string());
    }
    let canonical_path = canonical_candidate.to_string_lossy().into_owned();
    Ok(SearchOpenActionResponse {
        action: action.to_string(),
        source_type: SearchSourceType::SourceRoot,
        path: path.to_string(),
        canonical_path,
        allowed: true,
    })
}
/// Splits a `source-root/<id>/<relative path>` search result path into its
/// source-root id and relative path.
///
/// Backslashes are normalized to forward slashes and surrounding whitespace is
/// trimmed before parsing. Relative paths containing `..` or empty segments are
/// rejected.
///
/// # Errors
/// Returns an error string when the path is not a source-root result, lacks a
/// relative file path, has a blank component, or contains an unsafe segment.
fn parse_source_root_result_path(path: &str) -> Result<(String, String), String> {
    let normalized = path.trim().replace('\\', "/");
    let rest = match normalized.strip_prefix("source-root/") {
        Some(rest) => rest,
        None => return Err("search open action only supports source-root results".to_string()),
    };
    let (source_root_id, relative_path) = match rest.split_once('/') {
        Some(parts) => parts,
        None => {
            return Err("source-root result path must include a relative file path".to_string());
        }
    };
    let has_blank_component = [source_root_id, relative_path]
        .iter()
        .any(|component| component.trim().is_empty());
    if has_blank_component {
        return Err("source-root result path is incomplete".to_string());
    }
    let every_segment_safe = relative_path
        .split('/')
        .all(|segment| segment != ".." && !segment.is_empty());
    if !every_segment_safe {
        return Err("source-root result path contains an unsafe segment".to_string());
    }
    Ok((source_root_id.to_owned(), relative_path.to_owned()))
}
/// Merges two retrieval responses, capping the result count at the default
/// limit of the `RagContext` search profile.
fn merge_retrieval_responses(
    primary: soma_studio_core::NotebookRetrievalResponse,
    secondary: soma_studio_core::NotebookRetrievalResponse,
) -> soma_studio_core::NotebookRetrievalResponse {
    let rag_limit = SearchProfile::RagContext.default_limit();
    merge_retrieval_responses_with_limit(primary, secondary, rag_limit)
}
/// Merges two retrieval responses into one, keeping at most `limit` results.
///
/// * The query comes from `primary` unless it is empty, in which case the
///   query of `secondary` is used.
/// * Results are ordered by descending score, then by path, then by chunk index.
/// * The strategy is `"none"` when neither side has results. When only one
///   side has results, that side's strategy is kept. When both do, it stays
///   `"semantic"` or `"lexical"` only if both sides agree on that value, and is
///   `"hybrid"` otherwise.
fn merge_retrieval_responses_with_limit(
    primary: soma_studio_core::NotebookRetrievalResponse,
    secondary: soma_studio_core::NotebookRetrievalResponse,
    limit: usize,
) -> soma_studio_core::NotebookRetrievalResponse {
    let soma_studio_core::NotebookRetrievalResponse {
        query: primary_query,
        strategy: primary_strategy,
        results: primary_results,
    } = primary;
    let soma_studio_core::NotebookRetrievalResponse {
        query: secondary_query,
        strategy: secondary_strategy,
        results: secondary_results,
    } = secondary;

    // Record which sides contributed before the result vectors are merged.
    let primary_hit = !primary_results.is_empty();
    let secondary_hit = !secondary_results.is_empty();

    let query = if primary_query.is_empty() {
        secondary_query
    } else {
        primary_query
    };

    let mut merged = primary_results;
    merged.extend(secondary_results);
    merged.sort_by(|left, right| {
        let left_key = (
            std::cmp::Reverse(&left.score),
            &left.path,
            &left.chunk_index,
        );
        let right_key = (
            std::cmp::Reverse(&right.score),
            &right.path,
            &right.chunk_index,
        );
        left_key.cmp(&right_key)
    });
    merged.truncate(limit);

    let strategy = if primary_hit && secondary_hit {
        let combined = match (primary_strategy.as_str(), secondary_strategy.as_str()) {
            ("semantic", "semantic") => "semantic",
            ("lexical", "lexical") => "lexical",
            _ => "hybrid",
        };
        combined.to_string()
    } else if primary_hit {
        primary_strategy
    } else if secondary_hit {
        secondary_strategy
    } else {
        "none".to_string()
    };

    soma_studio_core::NotebookRetrievalResponse {
        query,
        strategy,
        results: merged,
    }
}
/// Builds the provider prompt by prefixing the user message with notebook
/// context taken from the retrieval results.
///
/// At most the first three results are included. Each is rendered as a header
/// line with the path, 1-based chunk number, and score, followed by the
/// snippet. Entries are separated by a blank line.
fn build_chat_message_with_retrieval(
    message: &str,
    retrieval: &soma_studio_core::NotebookRetrievalResponse,
) -> String {
    const MAX_CONTEXT_RESULTS: usize = 3;
    let mut context = String::new();
    for (position, result) in retrieval
        .results
        .iter()
        .take(MAX_CONTEXT_RESULTS)
        .enumerate()
    {
        if position > 0 {
            context.push_str("\n\n");
        }
        let entry = format!(
            "[{path} | chunk {chunk} | score {score}]\n{snippet}",
            path = result.path,
            chunk = result.chunk_index + 1,
            score = result.score,
            snippet = result.snippet
        );
        context.push_str(&entry);
    }
    format!(
        "You are answering inside Soma Studio.\nUse the notebook context when it is relevant to the request.\nIf the notebook context is insufficient, say so briefly instead of inventing details.\n\nNotebook context:\n{context}\n\nUser request:\n{message}"
    )
}
/// Joins the non-blank warnings into a single `" | "`-separated string.
///
/// Each warning is trimmed first. Returns `None` when nothing remains.
fn combine_warnings(warnings: Vec<String>) -> Option<String> {
    let kept: Vec<&str> = warnings
        .iter()
        .map(|warning| warning.trim())
        .filter(|warning| !warning.is_empty())
        .collect();
    if kept.is_empty() {
        return None;
    }
    Some(kept.join(" | "))
}
/// Returns the session id resolved from the bus, or the error
/// `"authentication required"` when none is available.
async fn require_session(bus: &mut Bus, app: &AppContext) -> Result<String, String> {
    match session_id_from_bus(bus, app).await {
        Some(session_id) => Ok(session_id),
        None => Err("authentication required".to_string()),
    }
}
/// Checks that the request's `Origin` header matches its `Host` header.
///
/// The expected origin is always built as `http://{host}`, so only plain-HTTP
/// origins can pass this check.
///
/// # Errors
/// Returns an error string when either header value is unavailable on the bus
/// or when the origin differs from the expected value.
fn validate_same_origin(bus: &mut Bus) -> Result<(), String> {
    let Some(origin) = bus
        .get_cloned::<RequestOrigin>()
        .ok()
        .and_then(|header| header.0)
    else {
        return Err("missing Origin header".to_string());
    };
    let Some(host) = bus
        .get_cloned::<RequestHost>()
        .ok()
        .and_then(|header| header.0)
    else {
        return Err("missing Host header".to_string());
    };
    let expected = format!("http://{host}");
    if origin == expected {
        Ok(())
    } else {
        Err(format!(
            "origin mismatch: expected {expected}, got {origin}"
        ))
    }
}
/// Resolves a provider endpoint URL to a loopback socket address.
///
/// Accepted hosts are `127.0.0.1`, `localhost` (mapped to `127.0.0.1`), and the
/// IPv6 loopback `::1`. The port comes from the URL, falling back to the
/// scheme's well-known default.
///
/// # Errors
/// Returns an error string when the endpoint is not a valid URL, has no usable
/// port, or names a host that is not one of the accepted loopback hosts.
fn endpoint_to_socket_addr(endpoint: &str) -> Result<std::net::SocketAddr, String> {
    let url =
        Url::parse(endpoint).map_err(|error| format!("invalid endpoint '{endpoint}': {error}"))?;
    // `Url::host_str` renders IPv6 literals inside brackets, and `localhost` is a
    // name rather than an address, so neither can be parsed as a `SocketAddr`
    // when formatted as `host:port`. Map the accepted hosts to explicit
    // loopback addresses instead.
    let loopback_ip = match url.host_str() {
        Some("127.0.0.1") | Some("localhost") => {
            Some(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST))
        }
        Some("[::1]") | Some("::1") => Some(std::net::IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)),
        _ => None,
    };
    match (loopback_ip, url.port_or_known_default()) {
        (Some(ip), Some(port)) => Ok(std::net::SocketAddr::new(ip, port)),
        _ => Err(format!("provider endpoint must be loopback: {endpoint}")),
    }
}
/// Dispatches a chat request to the named provider.
///
/// `on_delta` is handed to the provider-specific runner, which calls it with
/// each piece of streamed assistant content. Supported providers are `ollama`
/// and `lmstudio`; any other name is an error.
async fn run_provider_chat<F>(
    provider: &str,
    model_id: &str,
    message: &str,
    on_delta: F,
) -> Result<String, String>
where
    F: FnMut(&str),
{
    if provider == "ollama" {
        return run_ollama_chat(model_id, message, on_delta).await;
    }
    if provider == "lmstudio" {
        return run_lmstudio_chat(model_id, message, on_delta).await;
    }
    Err(format!("unsupported provider: {provider}"))
}
/// Lists the models reported by the local Ollama server (`GET /api/tags`).
///
/// Each summary uses the entry's `model` field as its id, falling back to
/// `name` when `model` is absent; the label is always `name`.
async fn fetch_ollama_models() -> Result<Vec<ProviderModelSummary>, String> {
    #[derive(Deserialize)]
    struct OllamaTagsResponse {
        models: Vec<OllamaModel>,
    }
    #[derive(Deserialize)]
    struct OllamaModel {
        name: String,
        model: Option<String>,
    }
    let body = fetch_loopback_json("http://127.0.0.1:11434", "/api/tags").await?;
    let parsed: OllamaTagsResponse = match serde_json::from_slice(&body) {
        Ok(parsed) => parsed,
        Err(error) => return Err(format!("invalid Ollama response: {error}")),
    };
    let mut summaries = Vec::with_capacity(parsed.models.len());
    for OllamaModel { name, model } in parsed.models {
        let id = match model {
            Some(id) => id,
            None => name.clone(),
        };
        summaries.push(ProviderModelSummary {
            id,
            label: name,
            provider: "ollama".to_string(),
        });
    }
    Ok(summaries)
}
/// Lists the models reported by the local LM Studio server (`GET /v1/models`).
///
/// The model id doubles as both the summary id and its label.
async fn fetch_lmstudio_models() -> Result<Vec<ProviderModelSummary>, String> {
    #[derive(Deserialize)]
    struct LmStudioModelsResponse {
        data: Vec<LmStudioModel>,
    }
    #[derive(Deserialize)]
    struct LmStudioModel {
        id: String,
    }
    let body = fetch_loopback_json("http://127.0.0.1:1234", "/v1/models").await?;
    let parsed: LmStudioModelsResponse = match serde_json::from_slice(&body) {
        Ok(parsed) => parsed,
        Err(error) => return Err(format!("invalid LM Studio response: {error}")),
    };
    let mut summaries = Vec::with_capacity(parsed.data.len());
    for LmStudioModel { id } in parsed.data {
        summaries.push(ProviderModelSummary {
            label: id.clone(),
            id,
            provider: "lmstudio".to_string(),
        });
    }
    Ok(summaries)
}
/// Sends a single-message chat request to the local Ollama server and streams
/// the reply.
///
/// The response is consumed as newline-delimited JSON. Every chunk that carries
/// `message.content` is appended to the returned string and passed to
/// `on_delta` as it arrives. If the stream yields no content, the whole body is
/// parsed as one non-streaming chat response instead.
///
/// # Errors
/// Returns an error string when the request cannot be serialized or sent, a
/// streamed line is not valid JSON, the server answers with a non-200 status,
/// or the fallback body cannot be parsed.
async fn run_ollama_chat<F>(
    model_id: &str,
    message: &str,
    mut on_delta: F,
) -> Result<String, String>
where
    F: FnMut(&str),
{
    #[derive(Serialize)]
    struct OllamaChatRequest<'a> {
        model: &'a str,
        messages: [OllamaMessageRequest<'a>; 1],
        stream: bool,
    }
    #[derive(Serialize)]
    struct OllamaMessageRequest<'a> {
        role: &'a str,
        content: &'a str,
    }
    #[derive(Deserialize)]
    struct OllamaChatResponse {
        message: OllamaMessageResponse,
    }
    #[derive(Deserialize)]
    struct OllamaMessageResponse {
        content: String,
    }
    let payload = serde_json::to_vec(&OllamaChatRequest {
        model: model_id,
        messages: [OllamaMessageRequest {
            role: "user",
            content: message,
        }],
        stream: true,
    })
    .map_err(|error| format!("failed to serialize Ollama chat request: {error}"))?;
    let mut accumulated = String::new();
    let mut buffer = String::new();
    // Bytes received after the last newline. They stay undecoded until their
    // line is complete, so a multi-byte UTF-8 character split across two
    // network chunks is not turned into replacement characters.
    let mut pending: Vec<u8> = Vec::new();
    let response = send_loopback_stream_request(
        "http://127.0.0.1:11434",
        "POST",
        "/api/chat",
        &payload,
        |chunk| {
            pending.extend_from_slice(chunk);
            // A newline is a single ASCII byte and never occurs inside a
            // multi-byte sequence, so everything up to it is safe to decode.
            if let Some(last_newline) = pending.iter().rposition(|&byte| byte == b'\n') {
                let complete: Vec<u8> = pending.drain(..=last_newline).collect();
                buffer.push_str(&String::from_utf8_lossy(&complete));
                process_ollama_stream_buffer(&mut buffer, &mut accumulated, &mut on_delta)?;
            }
            Ok(())
        },
    )
    .await?;
    // Check the status before flushing so an unterminated non-JSON error body
    // is reported as the HTTP failure rather than as a stream parse error.
    if response.status != 200 {
        return Err(format!(
            "Ollama returned HTTP {}: {}",
            response.status,
            String::from_utf8_lossy(&response.body)
        ));
    }
    if !pending.is_empty() {
        buffer.push_str(&String::from_utf8_lossy(&pending));
    }
    if !buffer.trim().is_empty() {
        // The line processor only consumes newline-terminated lines; terminate
        // the trailing line so the last chunk is not silently dropped.
        buffer.push('\n');
        process_ollama_stream_buffer(&mut buffer, &mut accumulated, &mut on_delta)?;
    }
    if !accumulated.is_empty() {
        return Ok(accumulated);
    }
    let parsed: OllamaChatResponse = serde_json::from_slice(&response.body)
        .map_err(|error| format!("invalid Ollama chat response: {error}"))?;
    Ok(parsed.message.content)
}
/// Sends a single-message chat request to the local LM Studio server and
/// streams the reply.
///
/// The response is consumed as server-sent events. Every event that carries
/// `choices[0].delta.content` is appended to the returned string and passed to
/// `on_delta` as it arrives. If the stream yields no content, the whole body is
/// parsed as one non-streaming chat completion instead.
///
/// # Errors
/// Returns an error string when the request cannot be serialized or sent, a
/// streamed event is not valid JSON, the server answers with a non-200 status,
/// or the fallback body lacks assistant content.
async fn run_lmstudio_chat<F>(
    model_id: &str,
    message: &str,
    mut on_delta: F,
) -> Result<String, String>
where
    F: FnMut(&str),
{
    #[derive(Serialize)]
    struct LmStudioChatRequest<'a> {
        model: &'a str,
        messages: [LmStudioMessageRequest<'a>; 1],
        stream: bool,
    }
    #[derive(Serialize)]
    struct LmStudioMessageRequest<'a> {
        role: &'a str,
        content: &'a str,
    }
    #[derive(Deserialize)]
    struct LmStudioChatResponse {
        choices: Vec<LmStudioChoice>,
    }
    #[derive(Deserialize)]
    struct LmStudioChoice {
        message: LmStudioChoiceMessage,
    }
    #[derive(Deserialize)]
    struct LmStudioChoiceMessage {
        content: Option<String>,
    }
    let payload = serde_json::to_vec(&LmStudioChatRequest {
        model: model_id,
        messages: [LmStudioMessageRequest {
            role: "user",
            content: message,
        }],
        stream: true,
    })
    .map_err(|error| format!("failed to serialize LM Studio chat request: {error}"))?;
    let mut accumulated = String::new();
    let mut buffer = String::new();
    // Bytes received after the last newline. They stay undecoded until a
    // newline arrives, so a multi-byte UTF-8 character split across two network
    // chunks is not turned into replacement characters.
    let mut pending: Vec<u8> = Vec::new();
    let response = send_loopback_stream_request(
        "http://127.0.0.1:1234",
        "POST",
        "/v1/chat/completions",
        &payload,
        |chunk| {
            pending.extend_from_slice(chunk);
            // A newline is a single ASCII byte and never occurs inside a
            // multi-byte sequence, so everything up to it is safe to decode.
            if let Some(last_newline) = pending.iter().rposition(|&byte| byte == b'\n') {
                let complete: Vec<u8> = pending.drain(..=last_newline).collect();
                buffer.push_str(&String::from_utf8_lossy(&complete));
                process_lmstudio_sse_buffer(&mut buffer, &mut accumulated, &mut on_delta)?;
            }
            Ok(())
        },
    )
    .await?;
    if !pending.is_empty() {
        buffer.push_str(&String::from_utf8_lossy(&pending));
    }
    if !buffer.trim().is_empty() {
        // Terminate the trailing event so the SSE processor consumes it.
        buffer.push_str("\n\n");
        process_lmstudio_sse_buffer(&mut buffer, &mut accumulated, &mut on_delta)?;
    }
    if response.status != 200 {
        return Err(format!(
            "LM Studio returned HTTP {}: {}",
            response.status,
            String::from_utf8_lossy(&response.body)
        ));
    }
    if !accumulated.is_empty() {
        return Ok(accumulated);
    }
    let parsed: LmStudioChatResponse = serde_json::from_slice(&response.body)
        .map_err(|error| format!("invalid LM Studio chat response: {error}"))?;
    parsed
        .choices
        .into_iter()
        .next()
        .and_then(|choice| choice.message.content)
        .ok_or_else(|| "LM Studio chat response did not include assistant content".to_string())
}
async fn fetch_loopback_json(endpoint: &str, path: &str) -> Result<Vec<u8>, String> {
let response = send_loopback_stream_request(endpoint, "GET", path, b"", |_| Ok(())).await?;
if response.status != 200 {
return Err(format!(
"unexpected HTTP {} from {endpoint}{path}: {}",
response.status,
String::from_utf8_lossy(&response.body)
));
}
Ok(response.body)
}
async fn post_loopback_json(endpoint: &str, path: &str, payload: &[u8]) -> Result<Vec<u8>, String> {
let response =
send_loopback_stream_request(endpoint, "POST", path, payload, |_| Ok(())).await?;
if response.status != 200 {
return Err(format!(
"loopback POST {endpoint}{path} returned HTTP {}: {}",
response.status,
String::from_utf8_lossy(&response.body)
));
}
Ok(response.body)
}
/// Dispatches an embedding request for `inputs` to the named provider.
///
/// Supported providers are `ollama` and `lmstudio`; any other name is an error.
async fn run_provider_embeddings(
    provider: &str,
    model_id: &str,
    inputs: &[String],
) -> Result<Vec<Vec<f32>>, String> {
    if provider == "ollama" {
        return run_ollama_embeddings(model_id, inputs).await;
    }
    if provider == "lmstudio" {
        return run_lmstudio_embeddings(model_id, inputs).await;
    }
    Err(format!("unsupported provider: {provider}"))
}
/// Requests embeddings for `inputs` from the local Ollama server
/// (`POST /api/embed`) and returns the `embeddings` array from the response.
async fn run_ollama_embeddings(model_id: &str, inputs: &[String]) -> Result<Vec<Vec<f32>>, String> {
    #[derive(Serialize)]
    struct OllamaEmbedRequest<'a> {
        model: &'a str,
        input: &'a [String],
    }
    #[derive(Deserialize)]
    struct OllamaEmbedResponse {
        embeddings: Vec<Vec<f32>>,
    }
    let request = OllamaEmbedRequest {
        model: model_id,
        input: inputs,
    };
    let payload = match serde_json::to_vec(&request) {
        Ok(bytes) => bytes,
        Err(error) => return Err(format!("failed to serialize Ollama embed request: {error}")),
    };
    let body = post_loopback_json("http://127.0.0.1:11434", "/api/embed", &payload).await?;
    match serde_json::from_slice::<OllamaEmbedResponse>(&body) {
        Ok(parsed) => Ok(parsed.embeddings),
        Err(error) => Err(format!("invalid Ollama embed response: {error}")),
    }
}
/// Requests embeddings for `inputs` from the local LM Studio server
/// (`POST /v1/embeddings`).
///
/// The returned vectors are ordered by the `index` field of each response item.
async fn run_lmstudio_embeddings(
    model_id: &str,
    inputs: &[String],
) -> Result<Vec<Vec<f32>>, String> {
    #[derive(Serialize)]
    struct LmStudioEmbeddingRequest<'a> {
        model: &'a str,
        input: &'a [String],
    }
    #[derive(Deserialize)]
    struct LmStudioEmbeddingResponse {
        data: Vec<LmStudioEmbeddingItem>,
    }
    #[derive(Deserialize)]
    struct LmStudioEmbeddingItem {
        embedding: Vec<f32>,
        index: usize,
    }
    let request = LmStudioEmbeddingRequest {
        model: model_id,
        input: inputs,
    };
    let payload = match serde_json::to_vec(&request) {
        Ok(bytes) => bytes,
        Err(error) => {
            return Err(format!(
                "failed to serialize LM Studio embed request: {error}"
            ));
        }
    };
    let body = post_loopback_json("http://127.0.0.1:1234", "/v1/embeddings", &payload).await?;
    let mut entries = match serde_json::from_slice::<LmStudioEmbeddingResponse>(&body) {
        Ok(parsed) => parsed.data,
        Err(error) => return Err(format!("invalid LM Studio embed response: {error}")),
    };
    // Stable sort, so items sharing an index keep their response order.
    entries.sort_by(|left, right| left.index.cmp(&right.index));
    let mut vectors = Vec::with_capacity(entries.len());
    for entry in entries {
        vectors.push(entry.embedding);
    }
    Ok(vectors)
}
/// Generates and persists embeddings for every notebook note and every indexed
/// source-root file, using the given provider and model.
///
/// Inputs are collected on blocking worker threads. Each input's chunks are
/// then sent to the provider one input at a time, the returned vector count is
/// validated against the chunk count, and the vectors are written on a blocking
/// worker thread. Notebook inputs are processed before source-root inputs.
///
/// The returned response lists the written items sorted by path, with
/// `embedded` set to the number of items.
///
/// # Errors
/// Stops at the first failure (input collection, provider call, vector-count
/// mismatch, write, or worker join) and returns it as a string. Items written
/// before the failure are not rolled back by this function.
async fn build_notebook_embeddings(
config: &soma_studio_core::AppConfig,
indexed_source_files: &[crate::storage::IndexedSourceFileRow],
provider: &str,
model_id: &str,
) -> Result<NotebookEmbeddingResponse, String> {
// Owned copies are needed because the blocking closures below must be `move`.
let config = config.clone();
let collect_config = config.clone();
let collect_source_root_config = config.clone();
let provider = provider.to_string();
let model_id = model_id.to_string();
// Collect the notebook chunk inputs off the async executor.
let chunk_inputs = tokio::task::spawn_blocking(move || {
crate::notebook::collect_embedding_inputs(&collect_config)
})
.await
.map_err(|error| format!("notebook embedding input worker join failed: {error}"))??;
let source_root_files = indexed_source_files.to_vec();
// Collect the source-root chunk inputs off the async executor.
let source_root_inputs = tokio::task::spawn_blocking(move || {
crate::ingest::collect_source_root_embedding_inputs(
&collect_source_root_config,
&source_root_files,
)
})
.await
.map_err(|error| format!("source-root embedding input worker join failed: {error}"))??;
let mut items = Vec::new();
// Notebook notes: embed, validate, then write each note in turn.
for input in chunk_inputs {
let vectors = run_provider_embeddings(&provider, &model_id, &input.chunks).await?;
// Reject a short or long provider response before anything is written.
validate_embedding_vector_count(
"notebook",
&input.path,
input.chunks.len(),
vectors.len(),
)?;
// Per-iteration clones so the `move` closure owns its own data.
let config = config.clone();
let provider = provider.clone();
let model_id = model_id.clone();
let path = input.path.clone();
let written = tokio::task::spawn_blocking(move || {
crate::notebook::write_note_embeddings(&config, &path, &provider, &model_id, &vectors)
})
.await
.map_err(|error| format!("notebook embedding write worker join failed: {error}"))??;
items.push(written);
}
// Source-root files: same embed / validate / write sequence.
for input in source_root_inputs {
let vectors = run_provider_embeddings(&provider, &model_id, &input.chunks).await?;
// This path string is only used to label a vector-count mismatch error.
let source_path = format!(
"source-root/{}/{}",
input.source_root_id, input.relative_path
);
validate_embedding_vector_count(
"source-root",
&source_path,
input.chunks.len(),
vectors.len(),
)?;
let config = config.clone();
let provider = provider.clone();
let model_id = model_id.clone();
let source_root_id = input.source_root_id.clone();
let relative_path = input.relative_path.clone();
let written = tokio::task::spawn_blocking(move || {
crate::ingest::write_source_root_embeddings(
&config,
&source_root_id,
&relative_path,
&provider,
&model_id,
&vectors,
)
})
.await
.map_err(|error| format!("source-root embedding write worker join failed: {error}"))??;
items.push(written);
}
// Sort by path so the response order does not depend on processing order.
items.sort_by(|left, right| left.path.cmp(&right.path));
Ok(NotebookEmbeddingResponse {
embedded: items.len(),
items,
})
}
/// Verifies that the provider returned one embedding vector per chunk.
///
/// `scope` and `path` are only used to label the error message.
///
/// # Errors
/// Returns a mismatch description when `expected` differs from `actual`.
fn validate_embedding_vector_count(
    scope: &str,
    path: &str,
    expected: usize,
    actual: usize,
) -> Result<(), String> {
    if expected != actual {
        return Err(format!(
            "{scope} embedding vector count mismatch for {path}: expected {expected}, got {actual}"
        ));
    }
    Ok(())
}
/// Loads the embedding status for notebook notes and indexed source-root files
/// for the given provider and model, and merges both into one response.
///
/// Each status lookup runs on a blocking worker thread; the notebook lookup
/// completes before the source-root lookup starts.
///
/// # Errors
/// Returns an error string when either lookup fails or its worker cannot be
/// joined.
async fn load_embedding_status(
    config: &soma_studio_core::AppConfig,
    indexed_source_files: &[crate::storage::IndexedSourceFileRow],
    provider: &str,
    model_id: &str,
) -> Result<NotebookEmbeddingResponse, String> {
    // Owned copies are needed because the blocking closures must be `move`.
    let notebook_config = config.clone();
    let notebook_provider = provider.to_string();
    let notebook_model_id = model_id.to_string();
    let notebook = tokio::task::spawn_blocking(move || {
        crate::notebook::embedding_status(&notebook_config, &notebook_provider, &notebook_model_id)
    })
    .await
    .map_err(|error| format!("notebook embedding status worker join failed: {error}"))??;
    let source_root_config = config.clone();
    let source_root_provider = provider.to_string();
    let source_root_model_id = model_id.to_string();
    let source_root_files = indexed_source_files.to_vec();
    let source_root = tokio::task::spawn_blocking(move || {
        crate::ingest::source_root_embedding_status(
            &source_root_config,
            &source_root_files,
            &source_root_provider,
            &source_root_model_id,
        )
    })
    .await
    .map_err(|error| format!("source-root embedding status worker join failed: {error}"))??;
    Ok(merge_embedding_responses(notebook, source_root))
}
/// Concatenates the items of two embedding responses and orders them by path,
/// then provider, then model id. `embedded` is the merged item count.
fn merge_embedding_responses(
    primary: NotebookEmbeddingResponse,
    secondary: NotebookEmbeddingResponse,
) -> NotebookEmbeddingResponse {
    let mut merged = primary.items;
    merged.extend(secondary.items);
    merged.sort_by(|left, right| {
        let left_key = (&left.path, &left.provider, &left.model_id);
        let right_key = (&right.path, &right.provider, &right.model_id);
        left_key.cmp(&right_key)
    });
    let embedded = merged.len();
    NotebookEmbeddingResponse {
        embedded,
        items: merged,
    }
}
/// Result of a loopback HTTP request: the status code plus the complete
/// response body.
struct LoopbackResponse {
/// Numeric HTTP status code of the response.
status: u16,
/// All response body bytes, collected in arrival order.
body: Vec<u8>,
}
/// Sends one HTTP/1.1 request to a loopback endpoint and streams the response.
///
/// `on_chunk` is called with every non-empty body frame as it arrives. The same
/// bytes are also collected and returned in [`LoopbackResponse::body`], so
/// callers can inspect the full body afterwards. The status code is returned
/// as-is; callers decide how to treat non-200 responses.
///
/// The TCP connect is bounded by a 3-second timeout. The request always carries
/// JSON `Accept` / `Content-Type` headers and `Connection: close`.
///
/// # Errors
/// Returns an error string when the endpoint is invalid or not loopback, the
/// connection or handshake fails, the request cannot be built or sent, a body
/// frame cannot be read, or `on_chunk` returns an error (which aborts the
/// stream).
async fn send_loopback_stream_request<F>(
    endpoint: &str,
    method: &str,
    path: &str,
    body: &[u8],
    mut on_chunk: F,
) -> Result<LoopbackResponse, String>
where
    F: FnMut(&[u8]) -> Result<(), String>,
{
    let url =
        Url::parse(endpoint).map_err(|error| format!("invalid endpoint '{endpoint}': {error}"))?;
    let host = url
        .host_str()
        .ok_or_else(|| format!("missing host in endpoint {endpoint}"))?;
    // The Host header must mirror the URI authority, including any non-default
    // port. `Url::port` is `None` when the port is the scheme default, which is
    // exactly the case where the port is omitted.
    let host_header = match url.port() {
        Some(port) => format!("{host}:{port}"),
        None => host.to_string(),
    };
    let address = endpoint_to_socket_addr(endpoint)?;
    let stream = timeout(Duration::from_secs(3), TcpStream::connect(address))
        .await
        .map_err(|_| format!("timed out connecting to {endpoint}"))?
        .map_err(|error| format!("failed to connect to {endpoint}: {error}"))?;
    let io = hyper_util::rt::TokioIo::new(stream);
    let (mut sender, connection) = http1::handshake(io).await.map_err(|error| {
        format!("failed to start HTTP client connection to {endpoint}: {error}")
    })?;
    // Drive the connection in the background; its errors surface through the
    // request and body futures below, so the task result is ignored.
    tokio::spawn(async move {
        let _ = connection.await;
    });
    let request = Request::builder()
        .method(method)
        .uri(path)
        .header("Host", host_header)
        .header("Accept", "application/json")
        .header("Content-Type", "application/json")
        .header("Connection", "close")
        .body(Full::new(Bytes::copy_from_slice(body)))
        .map_err(|error| format!("failed to build request for {endpoint}: {error}"))?;
    let response = sender
        .send_request(request)
        .await
        .map_err(|error| format!("failed to send request to {endpoint}: {error}"))?;
    let status = response.status().as_u16();
    let mut response_body = response.into_body();
    let mut collected = Vec::new();
    while let Some(frame_result) = response_body.frame().await {
        let frame = frame_result
            .map_err(|error| format!("failed to stream response from {endpoint}: {error}"))?;
        // Non-data frames (such as trailers) carry no body bytes and are skipped.
        if let Some(data) = frame.data_ref() {
            let bytes = data.as_ref();
            if !bytes.is_empty() {
                on_chunk(bytes)?;
                collected.extend_from_slice(bytes);
            }
        }
    }
    Ok(LoopbackResponse {
        status,
        body: collected,
    })
}
fn process_ollama_stream_buffer<F>(
buffer: &mut String,
accumulated: &mut String,
on_delta: &mut F,
) -> Result<(), String>
where
F: FnMut(&str),
{
#[derive(Deserialize)]
struct OllamaStreamChunk {
message: Option<OllamaChunkMessage>,
}
#[derive(Deserialize)]
struct OllamaChunkMessage {
content: Option<String>,
}
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer.drain(..=line_end);
if line.is_empty() {
continue;
}
let parsed: OllamaStreamChunk = serde_json::from_str(&line)
.map_err(|error| format!("invalid Ollama stream chunk: {error}"))?;
if let Some(content) = parsed.message.and_then(|message| message.content) {
accumulated.push_str(&content);
on_delta(&content);
}
}
Ok(())
}
/// Consumes every complete server-sent event from `buffer` and handles its
/// `data:` lines as LM Studio streaming chunks.
///
/// Events are separated by a blank line (`\r\n\r\n` or `\n\n`). Lines that are
/// not `data:` fields, empty payloads, and the `[DONE]` sentinel are skipped.
/// For each chunk whose first choice carries `delta.content`, the content is
/// appended to `accumulated` and passed to `on_delta`. Any trailing partial
/// event is left in `buffer` for the next call.
///
/// # Errors
/// Returns an error string when a `data:` payload is not valid JSON for a
/// stream chunk. The offending event has already been removed from `buffer`.
fn process_lmstudio_sse_buffer<F>(
    buffer: &mut String,
    accumulated: &mut String,
    on_delta: &mut F,
) -> Result<(), String>
where
    F: FnMut(&str),
{
    #[derive(Deserialize)]
    struct LmStudioDeltaEnvelope {
        choices: Vec<LmStudioDeltaChoice>,
    }
    #[derive(Deserialize)]
    struct LmStudioDeltaChoice {
        delta: LmStudioDelta,
    }
    #[derive(Deserialize)]
    struct LmStudioDelta {
        content: Option<String>,
    }
    loop {
        let separator_index = if let Some(index) = buffer.find("\r\n\r\n") {
            Some((index, 4))
        } else {
            buffer.find("\n\n").map(|index| (index, 2))
        };
        let Some((index, separator_len)) = separator_index else {
            break;
        };
        let block = buffer[..index].to_string();
        buffer.drain(..index + separator_len);
        for line in block.lines() {
            // Strip the field name exactly once. `trim_start_matches` would
            // also eat a payload that itself begins with `data:`.
            let Some(payload) = line.trim().strip_prefix("data:") else {
                continue;
            };
            let data = payload.trim();
            if data.is_empty() || data == "[DONE]" {
                continue;
            }
            let parsed: LmStudioDeltaEnvelope = serde_json::from_str(data)
                .map_err(|error| format!("invalid LM Studio stream chunk: {error}"))?;
            if let Some(content) = parsed
                .choices
                .into_iter()
                .next()
                .and_then(|choice| choice.delta.content)
            {
                accumulated.push_str(&content);
                on_delta(&content);
            }
        }
    }
    Ok(())
}
/// Copies the last-test fields from the stored status rows onto the matching
/// provider summaries.
///
/// A summary is matched to the first status row whose `provider` equals the
/// summary's `id`. Summaries without a matching row are returned unchanged.
fn merge_provider_summaries(
    mut providers: Vec<ProviderSummary>,
    statuses: Vec<crate::storage::ProviderStatusRow>,
) -> Vec<ProviderSummary> {
    for summary in providers.iter_mut() {
        let Some(row) = statuses.iter().find(|row| row.provider == summary.id) else {
            continue;
        };
        summary.last_test_ok = row.last_test_ok;
        summary.last_test_detail = row.last_test_detail.clone();
        summary.last_tested_at = row.last_tested_at.clone();
    }
    providers
}
/// Builds a `text/plain; charset=utf-8` response with the given status, body,
/// and additional headers.
///
/// # Panics
/// Panics if the response builder reports an error when the body is attached.
fn build_plain_response(
    status: StatusCode,
    body: String,
    headers: Vec<(String, String)>,
) -> HttpResponse {
    let base = Response::builder()
        .status(status)
        .header(CONTENT_TYPE, "text/plain; charset=utf-8");
    // Append the caller's headers after the content type, in the given order.
    let builder = headers
        .into_iter()
        .fold(base, |builder, (name, value)| builder.header(name, value));
    let payload = Full::new(Bytes::from(body))
        .map_err(|never| match never {})
        .boxed();
    builder
        .body(payload)
        .expect("response builder should be infallible")
}
// Unit tests for the helpers in this file that need no running server or provider.
#[cfg(test)]
mod tests {
use super::{
assistant_error_content, build_chat_message_with_retrieval, combine_warnings,
merge_retrieval_responses, merge_retrieval_responses_with_limit, normalize_provider_id,
parse_source_root_result_path, provider_models_error_response,
validate_embedding_vector_count,
};
use http::StatusCode;
use soma_studio_core::{NotebookRetrievalResponse, NotebookRetrievalResult};
// A partial reply takes priority over the error text; the error text is the fallback.
#[test]
fn assistant_error_content_prefers_partial_reply_when_present() {
assert_eq!(
assistant_error_content("partial answer", "provider failed"),
"partial answer"
);
assert_eq!(
assistant_error_content("", "provider failed"),
"provider failed"
);
}
// Well-formed source-root paths are split; other prefixes, `..` segments and
// empty segments are rejected.
#[test]
fn source_root_open_action_path_parser_rejects_unsafe_paths() {
assert_eq!(
parse_source_root_result_path("source-root/root-a/docs/topic.md").expect("path"),
("root-a".to_string(), "docs/topic.md".to_string())
);
assert!(parse_source_root_result_path("notebook/docs/topic.md").is_err());
assert!(parse_source_root_result_path("source-root/root-a/../secret.md").is_err());
assert!(parse_source_root_result_path("source-root/root-a//topic.md").is_err());
}
// The prompt must contain the context heading, the result path and snippet,
// and the original user request.
#[test]
fn retrieval_prompt_includes_chunk_context_and_request() {
let prompt = build_chat_message_with_retrieval(
"summarize this note",
&NotebookRetrievalResponse {
query: "summarize this note".to_string(),
strategy: "semantic".to_string(),
results: vec![NotebookRetrievalResult {
path: "notes/topic.md".to_string(),
format: "markdown".to_string(),
chunk_path: "notebook-chunks/notes/topic.json".to_string(),
chunk_index: 0,
score: 3,
snippet: "Important notebook context".to_string(),
provenance: "source=notebook/notes/topic.md".to_string(),
}],
},
);
assert!(prompt.contains("Notebook context:"));
assert!(prompt.contains("notes/topic.md"));
assert!(prompt.contains("Important notebook context"));
assert!(prompt.contains("User request:\nsummarize this note"));
}
// Non-empty warnings are joined with " | "; an all-empty list yields `None`.
#[test]
fn combine_warnings_merges_non_empty_messages() {
assert_eq!(
combine_warnings(vec!["one".to_string(), "two".to_string()]),
Some("one | two".to_string())
);
assert_eq!(combine_warnings(vec!["".to_string()]), None);
}
// Client-side mistakes map to 400, upstream/provider failures map to 502.
#[test]
fn provider_model_errors_keep_client_and_upstream_failures_separate() {
let missing =
provider_models_error_response(&"missing provider query parameter".to_string());
assert_eq!(missing.status(), StatusCode::BAD_REQUEST);
let unsupported =
provider_models_error_response(&"unsupported provider: custom".to_string());
assert_eq!(unsupported.status(), StatusCode::BAD_REQUEST);
let invalid_upstream =
provider_models_error_response(&"invalid Ollama response: expected value".to_string());
assert_eq!(invalid_upstream.status(), StatusCode::BAD_GATEWAY);
let connection_failed = provider_models_error_response(&"connection refused".to_string());
assert_eq!(connection_failed.status(), StatusCode::BAD_GATEWAY);
}
// Provider ids are trimmed and lowercased; whitespace-only input becomes empty.
#[test]
fn provider_id_normalization_is_shared_across_provider_routes() {
assert_eq!(normalize_provider_id(" Ollama "), "ollama");
assert_eq!(normalize_provider_id(" LMSTUDIO "), "lmstudio");
assert_eq!(normalize_provider_id(" "), "");
}
// When only the secondary response has results, its strategy is reported.
#[test]
fn retrieval_merge_preserves_single_strategy_when_only_one_side_has_hits() {
let merged = merge_retrieval_responses(
NotebookRetrievalResponse {
query: "topic".to_string(),
strategy: "semantic".to_string(),
results: Vec::new(),
},
NotebookRetrievalResponse {
query: "topic".to_string(),
strategy: "lexical".to_string(),
results: vec![NotebookRetrievalResult {
path: "source-root/root/topic.md".to_string(),
format: "markdown".to_string(),
chunk_path: "source-root-chunks/root/topic.json".to_string(),
chunk_index: 0,
score: 2,
snippet: "fallback lexical".to_string(),
provenance: "source=source-root/root/topic.md".to_string(),
}],
},
);
assert_eq!(merged.strategy, "lexical");
assert_eq!(merged.results.len(), 1);
}
// Results from both a semantic and a lexical response produce the "hybrid" strategy.
#[test]
fn retrieval_merge_marks_hybrid_when_semantic_and_lexical_hits_both_exist() {
let merged = merge_retrieval_responses(
NotebookRetrievalResponse {
query: "topic".to_string(),
strategy: "semantic".to_string(),
results: vec![NotebookRetrievalResult {
path: "notes/topic.md".to_string(),
format: "markdown".to_string(),
chunk_path: "notebook-chunks/notes/topic.json".to_string(),
chunk_index: 0,
score: 900,
snippet: "semantic".to_string(),
provenance: "source=notebook/notes/topic.md".to_string(),
}],
},
NotebookRetrievalResponse {
query: "topic".to_string(),
strategy: "lexical".to_string(),
results: vec![NotebookRetrievalResult {
path: "source-root/root/topic.md".to_string(),
format: "markdown".to_string(),
chunk_path: "source-root-chunks/root/topic.json".to_string(),
chunk_index: 0,
score: 2,
snippet: "lexical".to_string(),
provenance: "source=source-root/root/topic.md".to_string(),
}],
},
);
assert_eq!(merged.strategy, "hybrid");
assert_eq!(merged.results.len(), 2);
}
// With an explicit limit of 25, all ten results survive the merge.
#[test]
fn retrieval_merge_respects_explicit_limit_for_interactive_search() {
let primary_results = (0..10)
.map(|index| NotebookRetrievalResult {
path: format!("notes/{index}.md"),
format: "markdown".to_string(),
chunk_path: format!("notebook-chunks/notes/{index}.json"),
chunk_index: 0,
score: 10 - index,
snippet: "artifact".to_string(),
provenance: "source=notebook".to_string(),
})
.collect();
let merged = merge_retrieval_responses_with_limit(
NotebookRetrievalResponse {
query: "artifact".to_string(),
strategy: "lexical".to_string(),
results: primary_results,
},
NotebookRetrievalResponse {
query: "artifact".to_string(),
strategy: "none".to_string(),
results: Vec::new(),
},
25,
);
assert_eq!(merged.results.len(), 10);
}
// Matching counts pass; a mismatch names the scope and both counts.
#[test]
fn embedding_vector_count_mismatch_is_reported_before_artifact_write() {
assert!(validate_embedding_vector_count("notebook", "notes/a.md", 2, 2).is_ok());
let error = validate_embedding_vector_count("source-root", "source-root/root/a.md", 2, 1)
.expect_err("mismatch error");
assert!(error.contains("source-root embedding vector count mismatch"));
assert!(error.contains("expected 2, got 1"));
}
}