use crate::protocol::*;
use somatize_core::cache::CacheStore;
use somatize_core::event::Event;
use somatize_core::filter::Filter;
use somatize_core::store::{DataStore, LocalDataStore};
use somatize_core::value::Value;
use somatize_runtime::{EventBus, FilterLibrary, MemoryCache, Runner};
use std::sync::Arc;
use std::time::Instant;
/// A compute node that executes serialized plans on behalf of a coordinator.
///
/// Holds the filters registered on this node (and their states), the cache and
/// event bus handed to the local runner, and the stores used to resolve inputs
/// and to hand back large outputs.
pub struct Worker {
    /// Identifier sent in the registration message and matched by
    /// `RemoteTarget::WorkerId`.
    pub id: WorkerId,
    /// Resources advertised on registration; `tags` are matched by
    /// `RemoteTarget::Tag`.
    pub capabilities: Capabilities,
    /// Event bus handed to the runner (and used directly by the streaming
    /// path); receivers are obtained through `subscribe`.
    event_bus: Arc<EventBus>,
    /// Cache passed to the runner's `fit`/`forward`; an in-memory cache unless
    /// replaced with `with_cache`.
    cache: Arc<dyn CacheStore>,
    /// Registered filters and their states, keyed by plan node id.
    filters: FilterLibrary,
    /// Optional shared store used to resolve reference inputs and to stream
    /// large ones; unset unless configured with `with_data_store`.
    data_store: Option<Arc<dyn DataStore>>,
    /// Local store that receives outputs at or above the inline threshold; also
    /// passed to input resolution.
    temp_store: Arc<LocalDataStore>,
    /// Provisions per-plan Python environments for shipped filters'
    /// requirements.
    env_manager: crate::env_manager::EnvManager,
}
impl Worker {
    /// Store-referenced inputs with more rows than this are streamed chunk by
    /// chunk instead of being materialized; also the streaming chunk size.
    const STREAM_CHUNK_ROWS: usize = 1024;

    /// Interpreter used when a plan ships no requirements or its venv cannot be
    /// provisioned.
    const SYSTEM_PYTHON: &'static str = "python3";

    /// Creates a worker with an in-memory cache, no shared data store, and
    /// per-worker scratch directories under the system temp dir
    /// (`soma-uploads-{id}` for large outputs, `soma-envs-{id}` for venvs).
    pub fn new(id: impl Into<String>, capabilities: Capabilities) -> Self {
        let worker_id: String = id.into();
        let temp_path = std::env::temp_dir().join(format!("soma-uploads-{worker_id}"));
        let temp_store = LocalDataStore::new(temp_path);
        let env_path = std::env::temp_dir().join(format!("soma-envs-{worker_id}"));
        Self {
            id: worker_id,
            capabilities,
            event_bus: Arc::new(EventBus::new(256)),
            cache: Arc::new(MemoryCache::default()),
            filters: FilterLibrary::new(),
            data_store: None,
            temp_store: Arc::new(temp_store),
            env_manager: crate::env_manager::EnvManager::new(
                env_path,
                crate::env_manager::EnvType::Venv,
            ),
        }
    }

    /// Replaces the cache handed to the runner.
    pub fn with_cache(mut self, cache: Arc<dyn CacheStore>) -> Self {
        self.cache = cache;
        self
    }

    /// Attaches a shared data store used to resolve reference inputs.
    pub fn with_data_store(mut self, store: Arc<dyn DataStore>) -> Self {
        self.data_store = Some(store);
        self
    }

    /// Replaces the local store for large outputs with one rooted at `path`.
    pub fn with_temp_dir(mut self, path: std::path::PathBuf) -> Self {
        self.temp_store = Arc::new(LocalDataStore::new(path));
        self
    }

    /// Returns the local store that receives large outputs.
    pub fn temp_store(&self) -> &Arc<LocalDataStore> {
        &self.temp_store
    }

    /// Registers `filter` under `node_id`.
    pub fn register_filter(&mut self, node_id: impl Into<String>, filter: Box<dyn Filter>) {
        self.filters.register(node_id, filter);
    }

    /// Looks up the filter registered for `node_id`.
    pub fn get_filter(&self, node_id: &str) -> Option<Arc<dyn Filter>> {
        self.filters.get(node_id)
    }

    /// Returns a copy of the stored state for `node_id`, or `Value::Empty` if none.
    pub fn get_filter_state(&self, node_id: &str) -> Value {
        self.filters
            .get_state(node_id)
            .cloned()
            .unwrap_or(Value::Empty)
    }

    /// Stores `state` for `node_id`.
    pub fn set_filter_state(&mut self, node_id: &str, state: Value) {
        self.filters.set_state(node_id, state);
    }

    /// Packages a plan output for delivery.
    ///
    /// Outputs whose JSON encoding is at least `INLINE_THRESHOLD_BYTES` are written
    /// to the temp store and returned by reference. Everything else is returned
    /// inline, including outputs that fail to serialize or cannot be stored
    /// (best-effort fallback).
    pub fn wrap_output(&self, output: Value) -> OutputDelivery {
        // Serialize once: the same bytes decide inline-vs-reference and key the blob.
        if let Ok(bytes) = serde_json::to_vec(&output)
            && bytes.len() >= somatize_core::store::INLINE_THRESHOLD_BYTES
        {
            let key = somatize_core::cache::CacheKey::hash_data(&bytes);
            if let Ok(data_ref) = self.temp_store.put(&key, &output) {
                return OutputDelivery::Reference { data_ref };
            }
        }
        OutputDelivery::Inline { value: output }
    }

    /// Subscribes to the events emitted on this worker's event bus.
    pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver<Event> {
        self.event_bus.subscribe()
    }

    /// Builds the message announcing this worker and its capabilities.
    pub fn registration_message(&self) -> WorkerToCoordinator {
        WorkerToCoordinator::Register {
            worker_id: self.id.clone(),
            capabilities: self.capabilities.clone(),
        }
    }

    /// Executes a serialized plan on this worker.
    ///
    /// 1. Provisions an interpreter for the shipped filters' requirements.
    /// 2. Spawns a Python subprocess for any shipped pickled filters and
    ///    registers proxies (with any shipped states) for them.
    /// 3. In forward mode, streams store-referenced inputs with more than
    ///    `STREAM_CHUNK_ROWS` rows.
    /// 4. Otherwise resolves the input and runs the plan with the local runner.
    ///    In fit mode it stores the trained states and returns them.
    ///
    /// Every failure, including a failed subprocess spawn, is reported as
    /// `PlanResult::Failed` rather than panicking.
    pub fn execute_plan(&mut self, plan: &SerializedPlan) -> PlanResult {
        let start = Instant::now();
        let python_path = self.python_interpreter(plan);
        if let Err(error) = self.install_subprocess_filters(plan, &python_path) {
            return PlanResult::Failed {
                error,
                duration_ms: Self::elapsed_ms(start),
            };
        }
        // Streaming replays already-fitted states and cannot train, so only
        // forward passes qualify. Checked before resolving the input so large
        // references are never materialized just to be discarded.
        if matches!(plan.mode, ExecutionMode::Forward)
            && let Some(InputSource::Reference { data_ref }) = &plan.input
            && let Some(store) = self.data_store.clone()
            && let Ok(meta) = store.meta(data_ref)
            && meta.total_rows > Self::STREAM_CHUNK_ROWS
        {
            return self.execute_streamed_from_store(plan, &store, data_ref, &meta, start);
        }
        let x = plan
            .input
            .as_ref()
            .map(|src| src.resolve(self.data_store.as_deref(), &self.temp_store))
            .unwrap_or(Value::Empty);
        let runner = somatize_runtime::LocalRunner;
        let result = match &plan.mode {
            ExecutionMode::Fit { y } => runner
                .fit(
                    &plan.plan,
                    &self.filters,
                    self.cache.as_ref(),
                    &self.event_bus,
                    &x,
                    y.as_ref(),
                )
                .map(|(output, all_outputs)| {
                    // The runner reports trained states as `__state_<node_id>` outputs;
                    // keep them locally and return them to the caller.
                    let mut trained_states = std::collections::HashMap::new();
                    for (key, value) in &all_outputs {
                        if let Some(node_id) = key.strip_prefix("__state_") {
                            self.filters.set_state(node_id, value.clone());
                            trained_states.insert(node_id.to_string(), value.clone());
                        }
                    }
                    (output, trained_states)
                }),
            ExecutionMode::Forward => runner
                .forward(
                    &plan.plan,
                    &self.filters,
                    self.cache.as_ref(),
                    &self.event_bus,
                    &x,
                )
                .map(|output| (output, std::collections::HashMap::new())),
        };
        match result {
            Ok((output, states)) => PlanResult::Success {
                output: self.wrap_output(output),
                duration_ms: Self::elapsed_ms(start),
                states,
            },
            Err(e) => PlanResult::Failed {
                error: e.to_string(),
                duration_ms: Self::elapsed_ms(start),
            },
        }
    }

    /// Returns the interpreter for `plan`.
    ///
    /// This is a venv provisioned for the sorted, de-duplicated union of the
    /// shipped filters' requirements. It falls back to `SYSTEM_PYTHON` when
    /// there are no requirements or provisioning fails.
    fn python_interpreter(&mut self, plan: &SerializedPlan) -> String {
        // BTreeSet rather than HashSet: the requirements string must not depend
        // on hash-seed iteration order, or identical plans look different.
        let requirements: Vec<String> = plan
            .filters
            .iter()
            .flat_map(|sf| sf.requirements.iter().cloned())
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect();
        if requirements.is_empty() {
            return Self::SYSTEM_PYTHON.to_string();
        }
        match self
            .env_manager
            .ensure_env(&plan.plan_id, &requirements.join("\n"))
        {
            Ok(path) => {
                tracing::info!("Using venv for plan {}: {:?}", plan.plan_id, path);
                path.to_string_lossy().to_string()
            }
            Err(e) => {
                tracing::warn!("Failed to create venv, falling back to system python: {e}");
                Self::SYSTEM_PYTHON.to_string()
            }
        }
    }

    /// Spawns one Python subprocess hosting every pickled filter shipped with
    /// `plan` and registers a subprocess-backed proxy per node.
    ///
    /// Any state shipped with a filter is also stored. Does nothing when the
    /// plan ships no filters.
    ///
    /// # Errors
    /// Returns a human-readable message when the subprocess cannot be spawned.
    fn install_subprocess_filters(
        &mut self,
        plan: &SerializedPlan,
        python_path: &str,
    ) -> Result<(), String> {
        if plan.filters.is_empty() {
            return Ok(());
        }
        let filter_specs: Vec<(String, Vec<u8>, bool)> = plan
            .filters
            .iter()
            .map(|sf| (sf.node_id.clone(), sf.pickled_filter.clone(), sf.trainable))
            .collect();
        let process = crate::python_process::PythonProcess::spawn(python_path, &filter_specs)
            .map_err(|e| {
                tracing::error!("Failed to spawn Python process: {e}");
                format!("failed to spawn Python process: {e}")
            })?;
        // All proxies share the single subprocess behind one lock.
        let process = Arc::new(std::sync::Mutex::new(process));
        for sf in &plan.filters {
            let filter = Box::new(crate::python_process::SubprocessFilter::new(
                process.clone(),
                sf.node_id.clone(),
                sf.trainable,
            ));
            self.filters.register(&sf.node_id, filter);
            if let Some(state) = &sf.state {
                self.filters.set_state(&sf.node_id, state.clone());
            }
        }
        Ok(())
    }

    /// Runs a forward pass over a store-backed input.
    ///
    /// The input is read in chunks of `STREAM_CHUNK_ROWS` rows. Each chunk goes
    /// through the plan's filters using their current states.
    ///
    /// Returns the last output produced (including the final flush) with no
    /// trained states. Returns `PlanResult::Failed` if a row range cannot be
    /// read, a chunk fails, or the flush fails.
    fn execute_streamed_from_store(
        &mut self,
        plan: &SerializedPlan,
        store: &Arc<dyn DataStore>,
        data_ref: &somatize_core::store::DataRef,
        meta: &somatize_core::store::StoreMeta,
        start: Instant,
    ) -> PlanResult {
        use somatize_runtime::executors::stream::{FittedFilter, StreamExecutor};
        let node_ids: Vec<String> = plan.plan.node_ids().into_iter().map(String::from).collect();
        let fitted: Vec<FittedFilter> = node_ids
            .iter()
            .filter_map(|id| {
                let Some(filter) = self.filters.get(id) else {
                    // Nodes without a local filter are skipped (as before), but
                    // loudly: a silently truncated pipeline is hard to diagnose.
                    tracing::warn!("stream: no filter registered for node {id}; skipping");
                    return None;
                };
                let state = self.filters.get_state(id).cloned().unwrap_or(Value::Empty);
                Some(FittedFilter {
                    name: id.clone(),
                    filter,
                    state,
                })
            })
            .collect();
        let mut executor = StreamExecutor::new(fitted);
        let chunk_size = Self::STREAM_CHUNK_ROWS;
        let run_id = format!("worker_stream_{}", plan.plan_id);
        self.event_bus.emit(Event::RunStarted {
            run_id: run_id.clone(),
            plan_summary: somatize_core::event::PlanSummary {
                total_nodes: node_ids.len(),
                cached_nodes: 0,
                parallel_branches: 0,
            },
        });
        let mut last_output = Value::Empty;
        let total = meta.total_rows;
        let mut chunk_idx = 0;
        for row_start in (0..total).step_by(chunk_size) {
            let len = chunk_size.min(total - row_start);
            let chunk = match store.get_rows(data_ref, row_start, len) {
                Ok(c) => c,
                Err(e) => {
                    return PlanResult::Failed {
                        error: format!("get_rows({row_start}..{}): {e}", row_start + len),
                        duration_ms: Self::elapsed_ms(start),
                    };
                }
            };
            match executor.process_chunk(chunk) {
                Ok(Some(output)) => last_output = output,
                Ok(None) => {}
                Err(e) => {
                    return PlanResult::Failed {
                        error: format!("stream chunk {chunk_idx}: {e}"),
                        duration_ms: Self::elapsed_ms(start),
                    };
                }
            }
            chunk_idx += 1;
        }
        match executor.flush() {
            Ok(Some(output)) => last_output = output,
            Ok(None) => {}
            Err(e) => {
                return PlanResult::Failed {
                    error: format!("stream flush: {e}"),
                    duration_ms: Self::elapsed_ms(start),
                };
            }
        }
        tracing::info!(
            "Streamed {chunk_idx} chunks ({total} rows) in {}ms",
            start.elapsed().as_millis()
        );
        PlanResult::Success {
            output: self.wrap_output(last_output),
            duration_ms: Self::elapsed_ms(start),
            states: std::collections::HashMap::new(),
        }
    }

    /// Milliseconds since `start`, saturating instead of truncating the `u128`.
    fn elapsed_ms(start: Instant) -> u64 {
        u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// Returns whether this worker is addressed by `target`, either by exact id
    /// or by one of its capability tags.
    pub fn matches_target(&self, target: &somatize_core::filter::RemoteTarget) -> bool {
        match target {
            somatize_core::filter::RemoteTarget::WorkerId(id) => &self.id == id,
            somatize_core::filter::RemoteTarget::Tag(tag) => self.capabilities.tags.contains(tag),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use somatize_compiler::ExecutionPlan;
    use somatize_core::cache::CacheKey;
    use somatize_core::error::Result as SomaResult;
    use somatize_core::filter::{FilterKind, FilterMeta, StreamMode};
    use somatize_core::value::Value;

    /// Stateless filter that doubles every element of a tensor input and passes
    /// any other value through unchanged.
    struct TestDoubler;

    impl Filter for TestDoubler {
        fn config_hash(&self) -> CacheKey {
            CacheKey::from_parts(&[b"TestDoubler"])
        }

        fn fit(&self, _x: &Value, _y: Option<&Value>) -> SomaResult<Value> {
            Ok(Value::Empty)
        }

        fn forward(&self, x: &Value, _state: &Value) -> SomaResult<Value> {
            match x {
                Value::Tensor { values, shape } => {
                    let doubled: Vec<f64> = values.iter().map(|v| v * 2.0).collect();
                    Ok(Value::tensor(doubled, shape.clone()))
                }
                _ => Ok(x.clone()),
            }
        }

        fn meta(&self) -> FilterMeta {
            FilterMeta {
                name: "TestDoubler".into(),
                kind: FilterKind::Stateless,
                cacheable: true,
                differentiable: true,
                stream_mode: StreamMode::FixedState,
                distribution: somatize_core::filter::Distribution::Local,
                input_schema: None,
                output_schema: None,
            }
        }

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
    }

    /// Worker `test_worker` with 4 cores, no GPUs, and tags `cpu` and `test`.
    fn make_worker() -> Worker {
        Worker::new(
            "test_worker",
            Capabilities {
                cpu_cores: 4,
                ram_bytes: 8_000_000_000,
                gpus: vec![],
                python_envs: vec![],
                tags: vec!["cpu".into(), "test".into()],
            },
        )
    }

    /// Forward-mode plan that ships no filters and uses an inline tensor input
    /// (or no input when `input` is `None`).
    fn local_plan(plan_id: &str, plan: ExecutionPlan, input: Option<Value>) -> SerializedPlan {
        SerializedPlan {
            plan_id: plan_id.into(),
            plan,
            input: input.map(|value| crate::protocol::InputSource::Inline { value }),
            filters: vec![],
            mode: ExecutionMode::default(),
            metadata: serde_json::json!({}),
        }
    }

    /// Unwraps a successful result's inline output, panicking on anything else.
    fn expect_inline_output(result: PlanResult) -> Value {
        let output = match result {
            PlanResult::Success { output, .. } => output,
            other => panic!("expected success, got: {other:?}"),
        };
        match output {
            OutputDelivery::Inline { value } => value,
            _ => panic!("expected inline output"),
        }
    }

    #[test]
    fn worker_registration() {
        let worker = make_worker();
        let WorkerToCoordinator::Register {
            worker_id,
            capabilities,
        } = worker.registration_message()
        else {
            panic!("wrong message type");
        };
        assert_eq!(worker_id, "test_worker");
        assert_eq!(capabilities.cpu_cores, 4);
    }

    #[test]
    fn worker_executes_plan_successfully() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let plan = local_plan(
            "p_001",
            ExecutionPlan::Execute {
                node_id: "doubler".into(),
            },
            Some(Value::tensor(vec![1.0, 2.0, 3.0], vec![3])),
        );
        // No wall-clock assertion here: timing bounds are flaky on loaded CI.
        let value = expect_inline_output(worker.execute_plan(&plan));
        let (data, _) = value.as_tensor().unwrap();
        assert_eq!(data, &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn worker_handles_missing_filter() {
        let mut worker = make_worker();
        let plan = local_plan(
            "p_002",
            ExecutionPlan::Execute {
                node_id: "nonexistent".into(),
            },
            None,
        );
        let result = worker.execute_plan(&plan);
        assert!(matches!(result, PlanResult::Failed { .. }));
    }

    #[test]
    fn worker_matches_target_by_id() {
        let worker = make_worker();
        assert!(
            worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
                "test_worker".into()
            ))
        );
        assert!(
            !worker.matches_target(&somatize_core::filter::RemoteTarget::WorkerId(
                "other".into()
            ))
        );
    }

    #[test]
    fn worker_matches_target_by_tag() {
        let worker = make_worker();
        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("cpu".into())));
        assert!(worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("test".into())));
        assert!(!worker.matches_target(&somatize_core::filter::RemoteTarget::Tag("gpu".into())));
    }

    #[test]
    fn worker_executes_sequence() {
        let mut worker = make_worker();
        worker.register_filter("d1", Box::new(TestDoubler));
        worker.register_filter("d2", Box::new(TestDoubler));
        let plan = local_plan(
            "p_003",
            ExecutionPlan::Sequence(vec![
                ExecutionPlan::Execute {
                    node_id: "d1".into(),
                },
                ExecutionPlan::Execute {
                    node_id: "d2".into(),
                },
            ]),
            Some(Value::tensor(vec![5.0], vec![1])),
        );
        let value = expect_inline_output(worker.execute_plan(&plan));
        let (data, _) = value.as_tensor().unwrap();
        // Doubled twice: 5 -> 10 -> 20.
        assert_eq!(data, &[20.0]);
    }

    #[test]
    fn worker_emits_events() {
        let mut worker = make_worker();
        worker.register_filter("doubler", Box::new(TestDoubler));
        let mut rx = worker.subscribe();
        let plan = local_plan(
            "p_004",
            ExecutionPlan::Execute {
                node_id: "doubler".into(),
            },
            Some(Value::tensor(vec![1.0], vec![1])),
        );
        worker.execute_plan(&plan);
        // Drain everything buffered so far; stops at the first empty/lagged read.
        let events: Vec<Event> = std::iter::from_fn(|| rx.try_recv().ok()).collect();
        assert!(
            events
                .iter()
                .any(|e| matches!(e, Event::NodeStarted { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, Event::NodeCompleted { .. }))
        );
    }
}