pub mod api;
mod handlers;
mod settings;
use handlers::*;
use settings::*;
pub use settings::{ServerRuntimeState, classify_server_runtime_state, load_server_settings};
pub mod auth;
pub mod graph_id;
pub mod identity;
pub mod policy;
pub mod queries;
pub mod registry;
pub mod workload;
pub use graph_id::GraphId;
pub use identity::{AuthSource, GraphKey, ResolvedActor, Scope, TenantId};
pub use registry::{GraphHandle, GraphRegistry, InsertError, RegistryLookup, RegistrySnapshot};
use crate::queries::{QueryRegistry, check, format_check_breakages};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fs;
use std::io;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use api::{
BranchCreateOutput, BranchCreateRequest, BranchDeleteOutput, BranchListOutput,
BranchMergeOutput, BranchMergeRequest, ChangeOutput, ChangeRequest, CommitListOutput,
CommitListQuery, ErrorCode, ErrorOutput, ExportRequest, GraphInfo, GraphListResponse,
HealthOutput, IngestOutput, IngestRequest, InvokeStoredQueryRequest, InvokeStoredQueryResponse,
QueriesCatalogOutput, QueryRequest, ReadOutput, ReadRequest, SchemaApplyOutput,
SchemaApplyRequest, SchemaOutput, SnapshotQuery, ingest_output, schema_apply_output,
snapshot_payload,
};
pub use auth::{AWS_SECRET_ENV, EnvOrFileTokenSource, TokenSource, resolve_token_source};
use axum::body::{Body, Bytes};
use axum::extract::DefaultBodyLimit;
use axum::extract::{Extension, OriginalUri, Path, Query, Request, State};
use axum::http::StatusCode;
use axum::http::header::{AUTHORIZATION, CONTENT_TYPE, HeaderName, HeaderValue};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use axum::routing::{delete, get, post};
use axum::{Json, Router};
use color_eyre::eyre::{Result, WrapErr, bail, eyre};
use futures::stream;
use omnigraph::db::{Omnigraph, ReadTarget};
use omnigraph::error::{ManifestConflictDetails, ManifestErrorKind, OmniError};
use omnigraph::storage::normalize_root_uri;
use omnigraph_compiler::catalog::Catalog;
use omnigraph_compiler::json_params_to_param_map;
use omnigraph_compiler::query::parser::parse_query;
use omnigraph_compiler::{JsonParamMode, ParamMap};
pub use policy::{
PolicyAction, PolicyCompiler, PolicyConfig, PolicyDecision, PolicyEngine, PolicyExpectation,
PolicyRequest, PolicyResourceKind, PolicyTestConfig,
};
use serde::Deserialize;
use serde_json::Value;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tower_http::trace::TraceLayer;
use tracing::{error, info, warn};
use tracing_subscriber::EnvFilter;
use utoipa::OpenApi;
use utoipa::openapi::path::{Parameter, ParameterIn};
use utoipa::openapi::schema::{Object, Type};
use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
/// SHA-256 digest of a bearer token; only these digests are kept in memory.
type BearerTokenHash = [u8; 32];

/// Hashes a raw bearer token with SHA-256 so it can be compared against the
/// stored digests without retaining the plaintext.
fn hash_bearer_token(token: &str) -> BearerTokenHash {
    let digest = Sha256::digest(token.as_bytes());
    BearerTokenHash::try_from(&digest[..]).expect("SHA-256 always yields a 32-byte digest")
}
// OpenAPI description generated by utoipa from the handler annotations listed
// in `paths(...)`. The `bearer_token` security scheme is injected by
// `SecurityAddon`. `served_openapi` post-processes this document with
// `handlers::nest_paths_under_cluster_prefix`; presumably that maps the paths
// onto the `/graphs/{graph_id}` nesting used by `build_app` — confirm.
#[derive(OpenApi)]
#[openapi(
    info(
        title = "Omnigraph API",
        description = "HTTP API for the Omnigraph graph database",
    ),
    paths(
        handlers::server_health,
        handlers::server_graphs_list,
        handlers::server_snapshot,
        // deprecated; the #[deprecated] attribute on the handler
        // surfaces as `deprecated: true` on the OpenAPI operation.
        #[allow(deprecated)] handlers::server_read,
        handlers::server_query,
        handlers::server_export,
        #[allow(deprecated)] handlers::server_change,
        handlers::server_mutate,
        handlers::server_list_queries,
        handlers::server_invoke_query,
        handlers::server_schema_apply,
        handlers::server_schema_get,
        handlers::server_load,
        // deprecated; the #[deprecated] attribute on the handler surfaces as
        // `deprecated: true` on the OpenAPI operation.
        #[allow(deprecated)] handlers::server_ingest,
        handlers::server_branch_list,
        handlers::server_branch_create,
        handlers::server_branch_delete,
        handlers::server_branch_merge,
        handlers::server_commit_list,
        handlers::server_commit_show,
    ),
    modifiers(&SecurityAddon),
)]
pub struct ApiDoc;
/// Returns the OpenAPI document as exposed by the server.
///
/// Starts from the generated [`ApiDoc`] and rewrites its paths with
/// `handlers::nest_paths_under_cluster_prefix`.
pub fn served_openapi() -> utoipa::openapi::OpenApi {
    let mut served = ApiDoc::openapi();
    handlers::nest_paths_under_cluster_prefix(&mut served);
    served
}
/// utoipa modifier that registers the HTTP bearer security scheme.
struct SecurityAddon;

impl utoipa::Modify for SecurityAddon {
    /// Adds a `bearer_token` HTTP Bearer scheme, creating the `components`
    /// section first if the document does not have one yet.
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        let components = openapi.components.get_or_insert_with(Default::default);
        let bearer = SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer));
        components.add_security_scheme("bearer_token", bearer);
    }
}
/// Request-body cap applied router-wide by `build_app` (1 MiB).
const DEFAULT_REQUEST_BODY_LIMIT_BYTES: usize = 1_048_576;
/// Larger body cap that `build_app` applies to the `/load` and `/ingest` routes (32 MiB).
const INGEST_REQUEST_BODY_LIMIT_BYTES: usize = 32 * 1024 * 1024;
/// Crate version, captured from Cargo at compile time.
const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Optional source revision, read from `OMNIGRAPH_SOURCE_VERSION` at compile time.
const SERVER_SOURCE_VERSION: Option<&str> = option_env!("OMNIGRAPH_SOURCE_VERSION");
/// Settings consumed by [`serve`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Which graphs to serve and where their definitions came from.
    pub mode: ServerConfigMode,
    /// Address handed to `TcpListener::bind`.
    pub bind: String,
    /// Forwarded to `classify_server_runtime_state`; lets the server start with
    /// neither bearer tokens nor a policy configured.
    pub allow_unauthenticated: bool,
    /// When true, startup fails if any configured graph fails to open;
    /// otherwise failing graphs are logged and left out of the registry.
    pub require_all_graphs: bool,
}
/// How the server sources its graphs. Only the cluster (multi-graph) mode exists.
#[derive(Debug, Clone)]
pub enum ServerConfigMode {
    /// Serve every graph listed in a cluster config.
    Multi {
        /// Per-graph startup settings, opened by `open_multi_graph_state`.
        graphs: Vec<GraphStartupConfig>,
        /// Path of the cluster config; logged at startup and kept in `GraphRouting`.
        config_path: PathBuf,
        /// Optional server-wide policy, loaded with `PolicyEngine::load_server*`.
        server_policy: Option<PolicySource>,
    },
}
/// Where a policy bundle is read from.
#[derive(Debug, Clone)]
pub enum PolicySource {
    /// Policy loaded from a file path.
    File(PathBuf),
    /// Policy source text supplied directly (loaded via the `*_from_source` loaders).
    Inline(String),
}
/// Startup settings for one graph in cluster mode.
#[derive(Debug, Clone)]
pub struct GraphStartupConfig {
    /// Raw graph id; converted to a `GraphId`, and the graph fails to open if invalid.
    pub graph_id: String,
    /// Storage root URI; normalized with `normalize_root_uri` before opening.
    pub uri: String,
    /// Optional per-graph policy, installed on the opened database.
    pub policy: Option<PolicySource>,
    /// Optional embedding configuration applied to the opened database.
    pub embedding: Option<omnigraph::embedding::EmbeddingConfig>,
    /// Stored queries, validated against the graph's catalog at startup.
    pub queries: QueryRegistry,
}
/// Graph lookup table plus the config path the graphs were loaded from.
#[derive(Clone)]
pub struct GraphRouting {
    /// Registry of opened graph handles, keyed by graph id.
    pub registry: Arc<GraphRegistry>,
    /// Cluster config path; `None` for single-graph state.
    pub config_path: Option<PathBuf>,
}
/// Shared state handed to every axum handler and middleware.
#[derive(Clone)]
pub struct AppState {
    /// Graph registry and config path.
    routing: GraphRouting,
    /// Request admission limits (in-flight count / byte budget, per `workload::RejectReason`).
    workload: Arc<workload::WorkloadController>,
    /// SHA-256 hashes of configured bearer tokens paired with their actor names;
    /// the constructors in this file do not keep the plaintext tokens.
    bearer_tokens: Arc<[(BearerTokenHash, Arc<str>)]>,
    /// Server-wide policy; set by `new_multi`, `None` for single-graph state.
    server_policy: Option<Arc<PolicyEngine>>,
}
/// `io::Write` adapter that forwards every written buffer as a `Bytes` chunk
/// over an unbounded channel, so a synchronous exporter can feed a streamed
/// response body.
struct ExportStreamWriter {
    sender: mpsc::UnboundedSender<std::result::Result<Bytes, io::Error>>,
}

impl Write for ExportStreamWriter {
    /// Copies `buf` into a new chunk and sends it; reports `BrokenPipe` once
    /// the receiving side has been dropped.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let chunk = Bytes::copy_from_slice(buf);
        if self.sender.send(Ok(chunk)).is_err() {
            return Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "export stream closed",
            ));
        }
        Ok(buf.len())
    }

    /// Nothing is buffered locally, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Error type returned by HTTP handlers; its `IntoResponse` impl renders it as
/// a JSON `ErrorOutput` body with `status` as the HTTP status.
#[derive(Debug)]
pub struct ApiError {
    /// HTTP status sent with the response.
    status: StatusCode,
    /// Machine-readable code included in the JSON body.
    code: ErrorCode,
    /// Human-readable message, serialized as the `error` field.
    message: String,
    /// Merge conflicts reported by the engine; empty for other errors.
    merge_conflicts: Vec<api::MergeConflictOutput>,
    /// Expected/actual version details for manifest version conflicts.
    manifest_conflict: Option<api::ManifestConflictOutput>,
}
impl AppState {
    /// Builds single-graph state around an already-opened `db`.
    ///
    /// The graph is registered under the `default` graph id. `bearer_tokens`
    /// are `(actor, token)` pairs of which only SHA-256 hashes are kept. When
    /// `policy_engine` is set it is stored on the graph handle and also
    /// installed on `db` as its policy checker.
    pub fn new_single(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
        policy_engine: Option<PolicyEngine>,
        workload: workload::WorkloadController,
    ) -> Self {
        Self::new_single_with_queries(uri, db, bearer_tokens, policy_engine, workload, None)
    }

    /// Like [`AppState::new_single`], additionally attaching an optional
    /// stored-query registry to the graph handle.
    fn new_single_with_queries(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
        policy_engine: Option<PolicyEngine>,
        workload: workload::WorkloadController,
        queries: Option<Arc<QueryRegistry>>,
    ) -> Self {
        Self::build_single_mode(
            uri,
            db,
            hash_bearer_tokens(bearer_tokens),
            policy_engine.map(Arc::new),
            Arc::new(workload),
            queries,
        )
    }

    /// Single-graph state with no bearer tokens and no policy; workload
    /// limits come from `WorkloadController::from_env`.
    pub fn new(uri: String, db: Omnigraph) -> Self {
        Self::new_with_bearer_tokens(uri, db, Vec::new())
    }

    /// Single-graph state guarded by at most one bearer token, attributed to
    /// the `default` actor after passing through `normalize_bearer_token`.
    pub fn new_with_bearer_token(uri: String, db: Omnigraph, bearer_token: Option<String>) -> Self {
        Self::new_with_bearer_tokens(uri, db, default_actor_tokens(bearer_token))
    }

    /// Single-graph state with the given `(actor, token)` pairs and no policy.
    pub fn new_with_bearer_tokens(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
    ) -> Self {
        Self::new_with_bearer_tokens_and_policy(uri, db, bearer_tokens, None)
    }

    /// Single-graph state with `(actor, token)` pairs and an optional
    /// per-graph policy; workload limits come from the environment.
    pub fn new_with_bearer_tokens_and_policy(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
        policy_engine: Option<PolicyEngine>,
    ) -> Self {
        Self::new_single(
            uri,
            db,
            bearer_tokens,
            policy_engine,
            workload::WorkloadController::from_env(),
        )
    }

    /// Single-graph state with an explicit workload controller and no policy.
    pub fn new_with_workload(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
        workload: workload::WorkloadController,
    ) -> Self {
        Self::new_single(uri, db, bearer_tokens, None, workload)
    }

    /// Opens the graph at `uri` with no bearer tokens and no policy.
    ///
    /// # Errors
    /// Fails if the URI cannot be normalized or the graph cannot be opened.
    pub async fn open(uri: impl Into<String>) -> Result<Self> {
        Self::open_with_bearer_token(uri, None).await
    }

    /// Opens the graph at `uri`, guarded by at most one bearer token that is
    /// attributed to the `default` actor.
    ///
    /// # Errors
    /// Fails if the URI cannot be normalized or the graph cannot be opened.
    pub async fn open_with_bearer_token(
        uri: impl Into<String>,
        bearer_token: Option<String>,
    ) -> Result<Self> {
        Self::open_with_bearer_tokens(uri, default_actor_tokens(bearer_token)).await
    }

    /// Normalizes `uri`, opens the graph, and builds single-graph state with
    /// the given `(actor, token)` pairs and no policy.
    ///
    /// # Errors
    /// Fails if the URI cannot be normalized or `Omnigraph::open` fails.
    pub async fn open_with_bearer_tokens(
        uri: impl Into<String>,
        bearer_tokens: Vec<(String, String)>,
    ) -> Result<Self> {
        let uri = normalize_root_uri(&uri.into()).wrap_err("normalize graph URI")?;
        let db = Omnigraph::open(&uri).await?;
        Ok(Self::new_with_bearer_tokens(uri, db, bearer_tokens))
    }

    /// Opens the graph with an optional policy file and no stored queries.
    ///
    /// # Errors
    /// See [`AppState::open_single_with_queries`].
    pub async fn open_with_bearer_tokens_and_policy(
        uri: impl Into<String>,
        bearer_tokens: Vec<(String, String)>,
        policy_file: Option<&PathBuf>,
    ) -> Result<Self> {
        Self::open_single_with_queries(uri, bearer_tokens, policy_file, QueryRegistry::default())
            .await
    }

    /// Opens the graph with an optional policy file and a stored-query
    /// registry; the normalized URI serves as the graph label for policy
    /// loading and query validation.
    ///
    /// # Errors
    /// Fails if the URI cannot be normalized, the graph cannot be opened, the
    /// stored queries break against the catalog, or the policy fails to load.
    pub async fn open_single_with_queries(
        uri: impl Into<String>,
        bearer_tokens: Vec<(String, String)>,
        policy_file: Option<&PathBuf>,
        queries: QueryRegistry,
    ) -> Result<Self> {
        Self::open_single_with_queries_for_graph_id(uri, bearer_tokens, policy_file, queries, None)
            .await
    }

    /// Shared implementation of the `open_single_with_queries*` constructors.
    ///
    /// `graph_id`, or the normalized URI when `None`, is used as the label for
    /// stored-query validation and as the graph id passed to
    /// `PolicyEngine::load_graph`. The handle itself is still registered as
    /// `default` by `build_single_mode`.
    async fn open_single_with_queries_for_graph_id(
        uri: impl Into<String>,
        bearer_tokens: Vec<(String, String)>,
        policy_file: Option<&PathBuf>,
        queries: QueryRegistry,
        graph_id: Option<String>,
    ) -> Result<Self> {
        let uri = normalize_root_uri(&uri.into()).wrap_err("normalize graph URI")?;
        let graph_id = graph_id.unwrap_or_else(|| uri.clone());
        let db = Omnigraph::open(&uri).await?;
        let registry = validate_and_attach(queries, &db.catalog(), &graph_id)?;
        let policy_engine = match policy_file {
            Some(path) => Some(PolicyEngine::load_graph(path, &graph_id)?),
            None => None,
        };
        Ok(Self::new_single_with_queries(
            uri,
            db,
            bearer_tokens,
            policy_engine,
            workload::WorkloadController::from_env(),
            registry,
        ))
    }

    /// Assembles single-graph state from already-prepared parts.
    ///
    /// Installs `policy_engine` (if any) on `db` as its policy checker,
    /// normalizes `uri` (keeping the given string if normalization fails), and
    /// registers one handle under the `default` graph id. No server-wide
    /// policy and no config path are set.
    fn build_single_mode(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Arc<[(BearerTokenHash, Arc<str>)]>,
        policy_engine: Option<Arc<PolicyEngine>>,
        workload: Arc<workload::WorkloadController>,
        queries: Option<Arc<QueryRegistry>>,
    ) -> Self {
        let db = if let Some(policy) = policy_engine.as_ref() {
            let checker = Arc::clone(policy) as Arc<dyn omnigraph_policy::PolicyChecker>;
            db.with_policy(checker)
        } else {
            db
        };
        let uri = normalize_root_uri(&uri).unwrap_or(uri);
        let graph_id = GraphId::try_from("default").expect("'default' is a valid GraphId");
        let key = GraphKey::cluster(graph_id);
        let handle = Arc::new(GraphHandle {
            key,
            uri,
            engine: Arc::new(db),
            policy: policy_engine,
            queries,
        });
        let registry = Arc::new(
            GraphRegistry::from_handles(vec![handle])
                .expect("a single handle never collides on graph id"),
        );
        Self {
            routing: GraphRouting {
                registry,
                config_path: None,
            },
            workload,
            bearer_tokens,
            server_policy: None,
        }
    }

    /// Builds multi-graph (cluster) state from opened graph handles.
    ///
    /// # Errors
    /// Returns the registry's [`InsertError`] when the handles cannot all be
    /// registered (e.g. colliding graph ids).
    pub fn new_multi(
        handles: Vec<Arc<GraphHandle>>,
        bearer_tokens: Vec<(String, String)>,
        server_policy: Option<PolicyEngine>,
        workload: workload::WorkloadController,
        config_path: Option<PathBuf>,
    ) -> std::result::Result<Self, InsertError> {
        let bearer_tokens = hash_bearer_tokens(bearer_tokens);
        let registry = Arc::new(GraphRegistry::from_handles(handles)?);
        Ok(Self {
            routing: GraphRouting {
                registry,
                config_path,
            },
            workload: Arc::new(workload),
            bearer_tokens,
            server_policy: server_policy.map(Arc::new),
        })
    }

    /// Graph registry and config path backing this state.
    pub fn routing(&self) -> &GraphRouting {
        &self.routing
    }

    /// True when any bearer token, a server-wide policy, or any per-graph
    /// policy is configured.
    fn requires_bearer_auth(&self) -> bool {
        if !self.bearer_tokens.is_empty() {
            return true;
        }
        if self.server_policy.is_some() {
            return true;
        }
        self.routing.registry.snapshot_ref().any_per_graph_policy
    }

    /// Resolves `provided_token` to the actor of the first configured token
    /// with the same SHA-256 hash.
    ///
    /// Every entry is compared with `ct_eq` and the loop never exits early —
    /// presumably so response timing does not reveal which entry matched.
    fn authenticate_bearer_token(&self, provided_token: &str) -> Option<ResolvedActor> {
        let provided_hash = hash_bearer_token(provided_token);
        let mut matched: Option<Arc<str>> = None;
        for (hash, actor) in self.bearer_tokens.iter() {
            if bool::from(hash.ct_eq(&provided_hash)) && matched.is_none() {
                matched = Some(Arc::clone(actor));
            }
        }
        matched.map(ResolvedActor::cluster_static)
    }
}

/// Turns an optional single bearer token into `(actor, token)` pairs,
/// attributing it to the `default` actor after `normalize_bearer_token`.
fn default_actor_tokens(bearer_token: Option<String>) -> Vec<(String, String)> {
    normalize_bearer_token(bearer_token)
        .into_iter()
        .map(|token| ("default".to_string(), token))
        .collect()
}
/// Replaces each `(actor, token)` pair with `(sha256(token), actor)`,
/// preserving input order, so plaintext tokens are not kept in state.
fn hash_bearer_tokens(bearer_tokens: Vec<(String, String)>) -> Arc<[(BearerTokenHash, Arc<str>)]> {
    let mut hashed = Vec::with_capacity(bearer_tokens.len());
    for (actor, token) in bearer_tokens {
        hashed.push((hash_bearer_token(&token), Arc::<str>::from(actor)));
    }
    Arc::from(hashed)
}
impl ApiError {
pub fn unauthorized(message: impl Into<String>) -> Self {
Self {
status: StatusCode::UNAUTHORIZED,
code: ErrorCode::Unauthorized,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn forbidden(message: impl Into<String>) -> Self {
Self {
status: StatusCode::FORBIDDEN,
code: ErrorCode::Forbidden,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn bad_request(message: impl Into<String>) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::BadRequest,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn not_found(message: impl Into<String>) -> Self {
Self {
status: StatusCode::NOT_FOUND,
code: ErrorCode::NotFound,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn method_not_allowed(message: impl Into<String>) -> Self {
Self {
status: StatusCode::METHOD_NOT_ALLOWED,
code: ErrorCode::MethodNotAllowed,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn conflict(message: impl Into<String>) -> Self {
Self {
status: StatusCode::CONFLICT,
code: ErrorCode::Conflict,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn internal(message: impl Into<String>) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode::Internal,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn too_many_requests(message: impl Into<String>) -> Self {
Self {
status: StatusCode::TOO_MANY_REQUESTS,
code: ErrorCode::TooManyRequests,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn from_workload_reject(reject: workload::RejectReason) -> Self {
match reject {
workload::RejectReason::InFlightCountExceeded { .. }
| workload::RejectReason::ByteBudgetExceeded { .. } => {
Self::too_many_requests(reject.to_string())
}
}
}
fn merge_conflict(conflicts: Vec<api::MergeConflictOutput>) -> Self {
Self {
status: StatusCode::CONFLICT,
code: ErrorCode::Conflict,
message: summarize_merge_conflicts(&conflicts),
merge_conflicts: conflicts,
manifest_conflict: None,
}
}
fn manifest_version_conflict(message: String, details: api::ManifestConflictOutput) -> Self {
Self {
status: StatusCode::CONFLICT,
code: ErrorCode::Conflict,
message,
merge_conflicts: Vec::new(),
manifest_conflict: Some(details),
}
}
fn from_omni(err: OmniError) -> Self {
match err {
OmniError::Compiler(err) => Self::bad_request(err.to_string()),
OmniError::DataFusion(message) => Self::bad_request(format!("query: {message}")),
OmniError::Manifest(err) => match err.kind {
ManifestErrorKind::BadRequest => Self::bad_request(err.message),
ManifestErrorKind::NotFound => Self::not_found(err.message),
ManifestErrorKind::Conflict => match err.details {
Some(ManifestConflictDetails::ExpectedVersionMismatch {
table_key,
expected,
actual,
}) => Self::manifest_version_conflict(
err.message,
api::ManifestConflictOutput {
table_key,
expected,
actual,
},
),
_ => Self::conflict(err.message),
},
ManifestErrorKind::Internal => Self::internal(err.message),
},
OmniError::MergeConflicts(conflicts) => Self::merge_conflict(
conflicts
.iter()
.map(api::MergeConflictOutput::from)
.collect(),
),
OmniError::Lance(message) => Self::internal(format!("storage: {message}")),
OmniError::Io(err) => Self::internal(format!("io: {err}")),
OmniError::Policy(message) => Self::forbidden(message),
err @ OmniError::AlreadyInitialized { .. } => Self::conflict(err.to_string()),
}
}
}
/// Builds a one-line summary of merge conflicts for an error message.
///
/// At most three conflicts are listed as `table:row (kind)`, or `table (kind)`
/// when there is no row id, followed by `; and N more` for any remainder.
fn summarize_merge_conflicts(conflicts: &[api::MergeConflictOutput]) -> String {
    const PREVIEW_LIMIT: usize = 3;
    if conflicts.is_empty() {
        return "merge conflicts".to_string();
    }
    let shown = conflicts.len().min(PREVIEW_LIMIT);
    let mut parts: Vec<String> = Vec::with_capacity(shown);
    for conflict in &conflicts[..shown] {
        let kind = conflict.kind.as_str();
        let part = if let Some(row_id) = conflict.row_id.as_deref() {
            format!("{}:{} ({})", conflict.table_key, row_id, kind)
        } else {
            format!("{} ({})", conflict.table_key, kind)
        };
        parts.push(part);
    }
    let remaining = conflicts.len() - shown;
    let suffix = if remaining > 0 {
        format!("; and {} more", remaining)
    } else {
        String::new()
    };
    format!("merge conflicts: {}{}", parts.join("; "), suffix)
}
/// Value of the `Retry-After` header sent with 429 responses.
const RETRY_AFTER_SECONDS: &str = "60";

impl IntoResponse for ApiError {
    /// Renders the error as `status` + JSON `ErrorOutput`, adding a
    /// `Retry-After` header for `TooManyRequests`.
    fn into_response(self) -> Response {
        let Self {
            status,
            code,
            message,
            merge_conflicts,
            manifest_conflict,
        } = self;
        let mut headers = axum::http::HeaderMap::new();
        if matches!(code, ErrorCode::TooManyRequests) {
            let retry_after = axum::http::HeaderValue::from_static(RETRY_AFTER_SECONDS);
            headers.insert(axum::http::header::RETRY_AFTER, retry_after);
        }
        let body = ErrorOutput {
            error: message,
            code: Some(code),
            merge_conflicts,
            manifest_conflict,
        };
        (status, headers, Json(body)).into_response()
    }
}
/// Installs a global `fmt` tracing subscriber filtered by `RUST_LOG`-style
/// environment configuration, defaulting to `info` when none is set or it is
/// invalid.
pub fn init_tracing() {
    let filter = match EnvFilter::try_from_default_env() {
        Ok(filter) => filter,
        Err(_) => EnvFilter::new("info"),
    };
    let subscriber = tracing_subscriber::fmt().with_env_filter(filter);
    // Initialization errors (e.g. a subscriber already installed) are ignored.
    let _ = subscriber.try_init();
}
/// Emits one `warn!` event per non-fatal stored-query finding in `report`.
fn log_registry_warnings(label: &str, report: &queries::CheckReport) {
    report.warnings.iter().for_each(|warning| {
        warn!(graph = label, query = %warning.query, "stored query: {}", warning.message);
    });
}
/// Checks stored queries against the graph catalog.
///
/// Breakages produce a manifest error describing them; otherwise warnings are
/// logged and `Ok(())` is returned.
fn validate_registry_against_catalog(
    registry: &QueryRegistry,
    catalog: &Catalog,
    label: &str,
) -> omnigraph::error::Result<()> {
    let report = check(registry, catalog);
    if !report.has_breakages() {
        log_registry_warnings(label, &report);
        return Ok(());
    }
    Err(OmniError::manifest(format_check_breakages(label, &report)))
}
fn validate_and_attach(
queries: QueryRegistry,
catalog: &Catalog,
label: &str,
) -> Result<Option<Arc<QueryRegistry>>> {
validate_registry_against_catalog(&queries, catalog, label)
.map_err(|err| color_eyre::eyre::eyre!(err.to_string()))?;
Ok(if queries.is_empty() {
None
} else {
Some(Arc::new(queries))
})
}
/// Assembles the axum router for the server.
///
/// Per-graph routes are nested under `/graphs/{graph_id}` and wrapped by both
/// `require_bearer_auth` and `resolve_graph_handle`. The `/graphs` listing is
/// wrapped by `require_bearer_auth` only. `/healthz` and `/openapi.json` carry
/// neither middleware. Every route gets `DEFAULT_REQUEST_BODY_LIMIT_BYTES`
/// unless it sets its own limit, and HTTP tracing wraps the whole router.
pub fn build_app(state: AppState) -> Router {
    let per_graph_protected = Router::new()
        .route("/snapshot", get(server_snapshot))
        .route("/export", post(server_export))
        // Deprecated handlers are still mounted; the allow silences the
        // deprecation warning at the registration site.
        .route(
            "/read",
            post({
                #[allow(deprecated)]
                server_read
            }),
        )
        .route("/query", post(server_query))
        .route(
            "/change",
            post({
                #[allow(deprecated)]
                server_change
            }),
        )
        .route("/mutate", post(server_mutate))
        .route("/queries", get(server_list_queries))
        .route("/queries/{name}", post(server_invoke_query))
        .route("/schema", get(server_schema_get))
        .route("/schema/apply", post(server_schema_apply))
        // Bulk-upload routes raise the body cap. This route-level
        // `DefaultBodyLimit` sits inside the router-wide default applied at
        // the bottom of this function, so it is the limit these routes see.
        .route(
            "/load",
            post(server_load).layer(DefaultBodyLimit::max(INGEST_REQUEST_BODY_LIMIT_BYTES)),
        )
        .route(
            "/ingest",
            post({
                #[allow(deprecated)]
                server_ingest
            })
            .layer(DefaultBodyLimit::max(INGEST_REQUEST_BODY_LIMIT_BYTES)),
        )
        .route(
            "/branches",
            get(server_branch_list).post(server_branch_create),
        )
        // NOTE(review): `/branches/merge` is a static path, which the router
        // prefers over `/branches/{branch}`; `DELETE /branches/merge` likely
        // never reaches `server_branch_delete`, so a branch literally named
        // `merge` may not be deletable through this API — confirm.
        .route("/branches/{branch}", delete(server_branch_delete))
        .route("/branches/merge", post(server_branch_merge))
        .route("/commits", get(server_commit_list))
        .route("/commits/{commit_id}", get(server_commit_show))
        // `route_layer` wraps only the routes registered above. The layer added
        // last is the outermost one, so `require_bearer_auth` runs before
        // `resolve_graph_handle`.
        .route_layer(middleware::from_fn_with_state(
            state.clone(),
            resolve_graph_handle,
        ))
        .route_layer(middleware::from_fn_with_state(
            state.clone(),
            require_bearer_auth,
        ));
    let management = Router::new()
        .route("/graphs", get(server_graphs_list))
        .route_layer(middleware::from_fn_with_state(
            state.clone(),
            require_bearer_auth,
        ));
    let protected: Router<AppState> = Router::new()
        .nest("/graphs/{graph_id}", per_graph_protected)
        .merge(management);
    // Health and OpenAPI endpoints are registered outside `protected`, so no
    // auth middleware applies to them.
    Router::new()
        .route("/healthz", get(server_health))
        .route("/openapi.json", get(server_openapi))
        .merge(protected)
        .layer(DefaultBodyLimit::max(DEFAULT_REQUEST_BODY_LIMIT_BYTES))
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
pub async fn serve(config: ServerConfig) -> Result<()> {
let token_source = resolve_token_source().await?;
info!(source = token_source.name(), "loaded bearer token source");
let tokens = token_source.load().await?;
let has_policy_configured = match &config.mode {
ServerConfigMode::Multi {
graphs,
server_policy,
..
} => server_policy.is_some() || graphs.iter().any(|g| g.policy.is_some()),
};
let runtime_state = classify_server_runtime_state(
!tokens.is_empty(),
has_policy_configured,
config.allow_unauthenticated,
)?;
match runtime_state {
ServerRuntimeState::Open => warn!(
"running with --unauthenticated: no bearer tokens, no policy file, all \
requests permitted. This is for local dev only — do not expose to a \
network you don't fully trust."
),
ServerRuntimeState::DefaultDeny => warn!(
"bearer tokens are configured but no policy file is set — running in \
default-deny mode (only `read` actions are permitted for authenticated \
actors). Configure a graph or cluster policy bundle in the cluster config, \
run `omnigraph cluster apply`, and restart to enable Cedar rules."
),
ServerRuntimeState::PolicyEnabled => {}
}
let bind = config.bind.clone();
let state = match config.mode {
ServerConfigMode::Multi {
graphs,
config_path,
server_policy,
} => {
info!(
bind = %bind,
mode = "cluster",
graph_count = graphs.len(),
config = %config_path.display(),
"serving omnigraph"
);
open_multi_graph_state(
graphs,
tokens,
server_policy.as_ref(),
config_path,
config.require_all_graphs,
)
.await?
}
};
let listener = TcpListener::bind(&bind).await?;
axum::serve(listener, build_app(state))
.with_graceful_shutdown(shutdown_signal())
.await?;
Ok(())
}
/// Loads the per-graph policy for `graph_id` from a file or inline source text.
fn load_graph_policy(source: &PolicySource, graph_id: &str) -> Result<PolicyEngine> {
    let engine = match source {
        PolicySource::File(path) => PolicyEngine::load_graph(path, graph_id)?,
        PolicySource::Inline(text) => PolicyEngine::load_graph_from_source(text, graph_id)?,
    };
    Ok(engine)
}
pub async fn open_multi_graph_state(
graphs: Vec<GraphStartupConfig>,
tokens: Vec<(String, String)>,
server_policy_source: Option<&PolicySource>,
config_path: PathBuf,
require_all_graphs: bool,
) -> Result<AppState> {
use futures::StreamExt;
if graphs.is_empty() {
bail!("multi-graph mode requires at least one graph in the `graphs:` map");
}
let server_policy = match server_policy_source {
Some(PolicySource::File(path)) => Some(PolicyEngine::load_server(path)?),
Some(PolicySource::Inline(source)) => Some(PolicyEngine::load_server_from_source(source)?),
None => None,
};
let configured_graphs = graphs.len();
let results = futures::stream::iter(graphs.into_iter())
.map(|cfg| async move {
let graph_id = cfg.graph_id.clone();
open_single_graph(cfg).await.map_err(|err| (graph_id, err))
})
.buffer_unordered(4)
.collect::<Vec<_>>()
.await;
let mut handles = Vec::new();
let mut failed = 0usize;
for result in results {
match result {
Ok(handle) => handles.push(handle),
Err((graph_id, err)) => {
failed += 1;
warn!(
graph_id = %graph_id,
error = %err,
"graph quarantined during startup"
);
}
}
}
if require_all_graphs && failed > 0 {
bail!(
"strict multi-graph startup requires every graph to open ({} configured, {} failed)",
configured_graphs,
failed
);
}
if handles.is_empty() {
bail!(
"no healthy graphs opened from multi-graph startup config ({} configured, {} failed)",
configured_graphs,
failed
);
}
let workload = workload::WorkloadController::from_env();
let state = AppState::new_multi(handles, tokens, server_policy, workload, Some(config_path))
.map_err(|err| color_eyre::eyre::eyre!("multi-graph registry: {err}"))?;
Ok(state)
}
async fn open_single_graph(cfg: GraphStartupConfig) -> Result<Arc<GraphHandle>> {
let graph_id = GraphId::try_from(cfg.graph_id.clone())
.map_err(|err| color_eyre::eyre::eyre!("graph id '{}': {err}", cfg.graph_id))?;
let uri = normalize_root_uri(&cfg.uri)
.wrap_err_with(|| format!("normalize URI for graph '{}'", cfg.graph_id))?;
let db = Omnigraph::open(&uri)
.await
.map_err(|err| color_eyre::eyre::eyre!("open graph '{}' at {}: {err}", graph_id, uri))?;
let db = if let Some(embedding) = cfg.embedding {
db.with_embedding_config(Arc::new(embedding))
} else {
db
};
let queries = validate_and_attach(cfg.queries, &db.catalog(), graph_id.as_str())?;
let (policy_arc, db) = match &cfg.policy {
Some(source) => {
let policy = load_graph_policy(source, graph_id.as_str())?;
let policy_arc: Arc<PolicyEngine> = Arc::new(policy);
let checker = Arc::clone(&policy_arc) as Arc<dyn omnigraph_policy::PolicyChecker>;
(Some(policy_arc), db.with_policy(checker))
}
None => (None, db),
};
Ok(Arc::new(GraphHandle {
key: GraphKey::cluster(graph_id),
uri,
engine: Arc::new(db),
policy: policy_arc,
queries,
}))
}
async fn shutdown_signal() {
if let Err(err) = tokio::signal::ctrl_c().await {
error!(error = %err, "failed to install ctrl-c handler");
return;
}
info!("shutdown signal received");
}