use std::sync::Arc;
use vgi_rpc::RpcServer;
use crate::dispatch::Dispatcher;
use crate::function::ScalarFunction;
use crate::protocol::register;
/// Default protocol version string handed to the RPC server builder by
/// `Worker::build_server`; it is replaced by the value of the
/// `VGI_PROTOCOL_VERSION_OVERRIDE` environment variable when that is set.
pub const VGI_PROTOCOL_VERSION: &str = "1.0.0";
/// Protocol name handed to the RPC server builder by `Worker::build_server`.
pub const VGI_PROTOCOL_NAME: &str = "VgiProtocol";
/// Collects function, catalog, secret-type and setting registrations in a
/// [`Dispatcher`] and turns them into a running RPC server.
pub struct Worker {
/// Dispatcher that receives every `register_*` / `set_catalog` call.
disp: Dispatcher,
/// Optional server id; `build_server` substitutes `"vgi-rust-worker"` when
/// this is `None`.
server_id: Option<String>,
}
impl Default for Worker {
fn default() -> Self {
Worker::new()
}
}
impl Worker {
    /// Creates a worker with no server id configured.
    ///
    /// The dispatcher's catalog name is taken from the
    /// `VGI_WORKER_CATALOG_NAME` environment variable and falls back to
    /// `"example"` when the variable is unset or not valid Unicode.
    pub fn new() -> Self {
        let catalog_name =
            std::env::var("VGI_WORKER_CATALOG_NAME").unwrap_or_else(|_| "example".to_string());
        Worker {
            disp: Dispatcher::new(catalog_name),
            server_id: None,
        }
    }

    /// Sets the server id that [`Worker::build_server`] passes to the RPC
    /// server builder.
    ///
    /// Builder-style: consumes the worker and returns it with the id set.
    pub fn server_id(mut self, id: impl Into<String>) -> Self {
        self.server_id = Some(id.into());
        self
    }

    /// Registers a scalar function with the dispatcher.
    pub fn register_scalar(&mut self, f: impl ScalarFunction + 'static) {
        self.disp.register_scalar(Arc::new(f));
    }

    /// Registers a table function with the dispatcher.
    pub fn register_table(&mut self, f: impl crate::table_function::TableFunction + 'static) {
        self.disp.register_table(Arc::new(f));
    }

    /// Registers a table-in/out function with the dispatcher.
    pub fn register_table_in_out(
        &mut self,
        f: impl crate::table_in_out::TableInOutFunction + 'static,
    ) {
        self.disp.register_table_in_out(Arc::new(f));
    }

    /// Registers a table-buffering function with the dispatcher.
    pub fn register_buffering(
        &mut self,
        f: impl crate::buffering::TableBufferingFunction + 'static,
    ) {
        self.disp.register_buffering(Arc::new(f));
    }

    /// Registers an aggregate function with the dispatcher.
    pub fn register_aggregate(&mut self, f: impl crate::aggregate::AggregateFunction + 'static) {
        self.disp.register_aggregate(Arc::new(f));
    }

    /// Hands `model` to the dispatcher's `set_catalog`.
    pub fn set_catalog(&mut self, model: crate::catalog::CatalogModel) {
        self.disp.set_catalog(model);
    }

    /// Hands a secondary catalog `model` and its `functions` list to the
    /// dispatcher.
    ///
    /// NOTE(review): `functions` is forwarded unchanged; presumably these are
    /// the names of the functions that belong to the secondary catalog —
    /// confirm against `Dispatcher::register_secondary_catalog`.
    pub fn register_secondary_catalog(
        &mut self,
        model: crate::catalog::CatalogModel,
        functions: Vec<String>,
    ) {
        self.disp.register_secondary_catalog(model, functions);
    }

    /// Registers a secret type specification with the dispatcher.
    pub fn register_secret_type(&mut self, spec: crate::catalog::SecretTypeSpec) {
        self.disp.register_secret_type(spec);
    }

    /// Registers a setting specification with the dispatcher.
    pub fn register_setting(&mut self, spec: crate::catalog::SettingSpec) {
        self.disp.register_setting(spec);
    }

    /// Consumes the worker and builds the [`RpcServer`] that serves it.
    ///
    /// * The server id is the one set through [`Worker::server_id`], or
    ///   `"vgi-rust-worker"` when none was set.
    /// * The protocol version is the `VGI_PROTOCOL_VERSION_OVERRIDE`
    ///   environment variable when set, otherwise [`VGI_PROTOCOL_VERSION`].
    /// * The dispatcher is registered on the server via
    ///   `crate::protocol::register::register`.
    pub fn build_server(self) -> RpcServer {
        // `self` is owned here, so the id can be moved out instead of cloned.
        let server_id = self
            .server_id
            .unwrap_or_else(|| "vgi-rust-worker".to_string());
        let protocol_version = std::env::var("VGI_PROTOCOL_VERSION_OVERRIDE")
            .unwrap_or_else(|_| VGI_PROTOCOL_VERSION.to_string());
        let mut srv = RpcServer::builder()
            .server_id(server_id)
            .protocol_name(VGI_PROTOCOL_NAME)
            .protocol_version(protocol_version)
            .enable_describe(true)
            .build();
        register::register(&mut srv, Arc::new(self.disp));
        srv
    }

    /// Builds the server and serves it on the transport selected by the
    /// process command-line arguments.
    ///
    /// Transport selection, in priority order:
    ///
    /// 1. `--http` — `crate::transport::serve_http`, with the authenticator
    ///    produced by `build_authenticate`.
    /// 2. `--unix <path>` — `crate::transport::serve_unix` on `<path>`. An
    ///    optional `--idle-timeout <number>` is parsed as `f64`; when it is
    ///    absent the default `300.0` is used, and when it cannot be parsed a
    ///    warning is printed to stderr and the default is used.
    /// 3. Otherwise — `crate::transport::serve_stdio`.
    ///
    /// # Panics
    ///
    /// Panics when `--unix` is given without a following socket path, and
    /// (via `std::env::args`) when an argument is not valid Unicode.
    pub fn run(self) {
        // NOTE(review): unit is presumably seconds — confirm against `serve_unix`.
        const DEFAULT_IDLE_TIMEOUT: f64 = 300.0;

        let args: Vec<String> = std::env::args().collect();
        let server = Arc::new(self.build_server());

        if args.iter().any(|a| a == "--http") {
            crate::transport::serve_http(server, build_authenticate());
            return;
        }

        if let Some(i) = args.iter().position(|a| a == "--unix") {
            // Borrow the path straight out of `args`; no copy is needed.
            let path = args.get(i + 1).expect("--unix requires a socket path");
            let idle = args
                .iter()
                .position(|a| a == "--idle-timeout")
                .and_then(|j| args.get(j + 1))
                .map_or(DEFAULT_IDLE_TIMEOUT, |raw| {
                    raw.parse::<f64>().unwrap_or_else(|_| {
                        // Keep the best-effort fallback, but make the
                        // misconfiguration visible instead of swallowing it.
                        eprintln!(
                            "warning: ignoring invalid --idle-timeout value {:?}; using default {}",
                            raw, DEFAULT_IDLE_TIMEOUT
                        );
                        DEFAULT_IDLE_TIMEOUT
                    })
                });
            crate::transport::serve_unix(server, path, idle);
            return;
        }

        crate::transport::serve_stdio(server);
    }
}
/// Builds the optional bearer-token authenticator passed to the HTTP
/// transport.
///
/// Accepted tokens are read from two environment variables:
///
/// * `VGI_BEARER_TOKENS` — comma-separated `token=principal` pairs. Entries
///   without an `=` are skipped; tokens and principals are trimmed.
/// * `VGI_TEST_BEARER_TOKEN` — a single token mapped to `"test-principal"`,
///   unless `VGI_BEARER_TOKENS` already assigned that token a principal.
///
/// Empty tokens are never accepted as credentials. Returns `None` when no
/// usable token is configured.
///
/// The returned callback reads the `authorization` header, expects the
/// `Bearer` scheme (matched ASCII case-insensitively), and yields an
/// authenticated `AuthContext` in the `"bearer"` domain for a known token or
/// a permission error otherwise.
fn build_authenticate() -> Option<vgi_rpc::Authenticate> {
    let tokens = collect_bearer_tokens(
        std::env::var("VGI_BEARER_TOKENS").ok().as_deref(),
        std::env::var("VGI_TEST_BEARER_TOKEN").ok().as_deref(),
    );
    if tokens.is_empty() {
        return None;
    }
    Some(std::sync::Arc::new(
        move |req: &vgi_rpc::AuthRequest<'_>| {
            let presented = req
                .header("authorization")
                .and_then(|h| bearer_token_from_header(h));
            match presented {
                None => Err(vgi_rpc::RpcError::permission_error(
                    "bearer token required but not provided",
                )),
                Some(tok) => match tokens.get(tok) {
                    Some(principal) => Ok(vgi_rpc::AuthContext {
                        domain: "bearer".to_string(),
                        authenticated: true,
                        principal: principal.clone(),
                        claims: Default::default(),
                    }),
                    None => Err(vgi_rpc::RpcError::permission_error(
                        "bearer token was rejected",
                    )),
                },
            }
        },
    ))
}

/// Builds the token → principal table from the raw `VGI_BEARER_TOKENS` value
/// (`pairs`) and the raw `VGI_TEST_BEARER_TOKEN` value (`test_token`).
fn collect_bearer_tokens(
    pairs: Option<&str>,
    test_token: Option<&str>,
) -> std::collections::HashMap<String, String> {
    let mut tokens = std::collections::HashMap::new();
    if let Some(pairs) = pairs {
        for (token, principal) in pairs.split(',').filter_map(|pair| pair.split_once('=')) {
            let token = token.trim();
            // An empty token would let a bare `Bearer ` header authenticate.
            if !token.is_empty() {
                tokens.insert(token.to_string(), principal.trim().to_string());
            }
        }
    }
    // Trim like the presented token is trimmed, otherwise a value with stray
    // whitespace could never match; drop it entirely when empty.
    if let Some(token) = test_token.map(str::trim).filter(|t| !t.is_empty()) {
        tokens
            .entry(token.to_string())
            .or_insert_with(|| "test-principal".to_string());
    }
    tokens
}

/// Extracts the trimmed credential from an `Authorization` header value that
/// uses the `Bearer` scheme; returns `None` for any other value.
fn bearer_token_from_header(header: &str) -> Option<&str> {
    const SCHEME: &str = "bearer ";
    // `get` returns `None` for headers that are too short or that would be
    // split inside a multi-byte character, so no slicing here can panic.
    let scheme = header.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }
    header.get(SCHEME.len()..).map(str::trim)
}