use std::sync::Arc;
use chrono::Utc;
use tokio_retry::strategy::ExponentialBackoff;
use tokio_retry::Retry;
use tracing::{error, info, warn};
use crate::client::{KeyComputeClient, OllamaClient};
use crate::error::NodeTokenError;
use crate::protocol::ollama::OllamaChatRequest;
use crate::protocol::types::{
NodeTaskCompleteRequest, NodeTaskCompleteResponse, NodeTaskEnvelope, NodeTaskResult,
};
use crate::storage::SessionData;
#[allow(dead_code)] type Result<T> = std::result::Result<T, NodeTokenError>;
#[allow(dead_code)] pub struct TaskExecutor {
client: Arc<KeyComputeClient>,
ollama_client: Arc<OllamaClient>,
session: SessionData,
node_status: Arc<tokio::sync::Mutex<String>>,
server_failure_count: Arc<tokio::sync::Mutex<u32>>,
failure_threshold: Arc<tokio::sync::Mutex<u32>>,
}
impl TaskExecutor {
#[allow(dead_code)] pub fn new(
client: Arc<KeyComputeClient>,
ollama_client: Arc<OllamaClient>,
session: SessionData,
) -> Self {
Self {
client,
ollama_client,
session,
node_status: Arc::new(tokio::sync::Mutex::new("unknown".to_string())),
server_failure_count: Arc::new(tokio::sync::Mutex::new(0)),
failure_threshold: Arc::new(tokio::sync::Mutex::new(3)),
}
}
#[allow(dead_code)] pub async fn execute(&self, envelope: NodeTaskEnvelope) {
let task_id = envelope.task_id;
let lease_id = envelope.lease_id;
let deadline_ms = envelope.deadline_unix_ms;
let grace_until_ms = envelope.complete_grace_until_unix_ms;
info!(
"Executing task: task_id={}, model={}, deadline_ms={}, grace_until_ms={}",
task_id, envelope.model, deadline_ms, grace_until_ms
);
let result = match self.execute_ollama(&envelope).await {
Ok(response) => {
info!("Task {} executed successfully", task_id);
NodeTaskResult::Succeeded { response }
}
Err(e) => {
error!("Task {} execution failed: {}", task_id, e);
NodeTaskResult::Failed {
code: "ollama_error".to_string(),
message: e.to_string(),
}
}
};
self.complete_with_retry(task_id, lease_id, result, deadline_ms, grace_until_ms)
.await;
}
#[allow(dead_code)] async fn execute_ollama(
&self,
envelope: &NodeTaskEnvelope,
) -> Result<crate::protocol::types::ChatCompletionResponse> {
let chat_req = &envelope.payload.chat;
let ollama_req = self.chat_request_to_ollama(chat_req, &envelope.model);
let ollama_resp = self.ollama_client.chat(&ollama_req).await?;
Ok(self.ollama_response_to_chat(&ollama_resp, &envelope.model))
}
#[allow(dead_code)] fn chat_request_to_ollama(
&self,
chat_req: &crate::protocol::types::ChatCompletionRequest,
model: &str,
) -> OllamaChatRequest {
let messages: Vec<crate::protocol::ollama::OllamaMessage> = chat_req
.messages
.iter()
.map(|msg| crate::protocol::ollama::OllamaMessage {
role: msg.role.to_string(),
content: msg.content.clone(),
})
.collect();
OllamaChatRequest {
model: model.to_string(),
messages,
stream: false, }
}
#[allow(dead_code)] fn ollama_response_to_chat(
&self,
ollama_resp: &crate::protocol::ollama::OllamaChatResponse,
model: &str,
) -> crate::protocol::types::ChatCompletionResponse {
use crate::protocol::types::{
ChatCompletionResponse, CompletionChoice, MessageContent, Usage,
};
use uuid::Uuid;
let choice_id = format!("chatcmpl-{}", Uuid::new_v4().simple());
ChatCompletionResponse {
id: choice_id,
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp(),
model: model.to_string(),
choices: vec![CompletionChoice {
index: 0,
message: MessageContent {
role: "assistant".to_string(),
content: ollama_resp.message.content.clone(),
},
finish_reason: Some("stop".to_string()),
}],
usage: Usage {
prompt_tokens: ollama_resp.prompt_eval_count,
completion_tokens: ollama_resp.eval_count,
total_tokens: ollama_resp.prompt_eval_count + ollama_resp.eval_count,
},
}
}
#[allow(dead_code)] async fn complete_with_retry(
&self,
task_id: uuid::Uuid,
lease_id: uuid::Uuid,
result: NodeTaskResult,
deadline_ms: i64,
grace_until_ms: i64,
) {
let req = NodeTaskCompleteRequest {
protocol_version: "node.v1".to_string(),
node_id: self.session.node_id,
session_id: self.session.session_id,
task_id,
lease_id,
result,
};
let now_ms = Utc::now().timestamp_millis();
let retry_deadline = if now_ms < deadline_ms {
deadline_ms
} else {
grace_until_ms
};
let max_retry_duration = std::cmp::max(0, retry_deadline - now_ms);
if max_retry_duration <= 0 {
warn!(
"Task {} past grace period, attempting one-shot complete",
task_id
);
match self.client.complete(task_id, &req).await {
Ok(resp) => {
info!(
"Task {} completed (one-shot): action={:?}",
task_id, resp.action
);
*self.node_status.lock().await = resp.node_status.clone();
*self.server_failure_count.lock().await = resp.server_failure_count;
*self.failure_threshold.lock().await = resp.failure_threshold;
self.log_complete_response(task_id, &resp);
}
Err(e) => {
error!("Task {} one-shot complete failed: {}", task_id, e);
}
}
return;
}
let max_retries =
std::cmp::max(1, (max_retry_duration as f64 / 1000.0).ceil() as usize / 5);
let retry_strategy = ExponentialBackoff::from_millis(100)
.max_delay(std::time::Duration::from_secs(5))
.take(max_retries);
info!(
"Starting complete retry for task {}: max_duration={}ms, max_retries={}",
task_id, max_retry_duration, max_retries
);
match Retry::spawn(retry_strategy, || async {
match self.client.complete(task_id, &req).await {
Ok(resp) => {
info!("Task {} completed: action={:?}", task_id, resp.action);
*self.node_status.lock().await = resp.node_status.clone();
*self.server_failure_count.lock().await = resp.server_failure_count;
*self.failure_threshold.lock().await = resp.failure_threshold;
self.log_complete_response(task_id, &resp);
Ok(resp)
}
Err(e) => {
warn!("Complete failed for task {}: {}", task_id, e);
Err(e)
}
}
})
.await
{
Ok(_) => {
info!("Task {} complete succeeded", task_id);
}
Err(e) => {
error!("Task {} complete failed after retries: {}", task_id, e);
}
}
}
#[allow(dead_code)] fn log_complete_response(&self, task_id: uuid::Uuid, resp: &NodeTaskCompleteResponse) {
info!(
"Complete response for task {}: action={:?}, task_status={}, node_status={}, failure_count={}/{}",
task_id,
resp.action,
resp.task_status,
resp.node_status,
resp.server_failure_count,
resp.failure_threshold
);
if resp.node_status == "excluded" {
warn!(
"Node EXCLUDED after task {} complete, will stop poll but continue heartbeat",
task_id
);
}
}
}
impl Clone for TaskExecutor {
fn clone(&self) -> Self {
Self {
client: self.client.clone(),
ollama_client: self.ollama_client.clone(),
session: self.session.clone(),
node_status: self.node_status.clone(),
server_failure_count: self.server_failure_count.clone(),
failure_threshold: self.failure_threshold.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::types::{ChatCompletionRequest, Message, MessageRole};
#[test]
fn test_chat_request_to_ollama_conversion() {
let chat_req = ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![
Message {
role: MessageRole::System,
content: "You are a helpful assistant.".to_string(),
},
Message {
role: MessageRole::User,
content: "Hello!".to_string(),
},
],
stream: Some(false),
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stop: None,
};
let executor = create_test_executor();
let ollama_req = executor.chat_request_to_ollama(&chat_req, "test-model");
assert_eq!(ollama_req.model, "test-model");
assert_eq!(ollama_req.messages.len(), 2);
assert_eq!(ollama_req.messages[0].role, "system");
assert_eq!(ollama_req.messages[1].role, "user");
assert!(!ollama_req.stream);
}
fn create_test_executor() -> TaskExecutor {
use crate::client::{KeyComputeClient, OllamaClient};
use crate::protocol::types::NodeCapabilities;
use crate::storage::SessionData;
use std::sync::Arc;
let client = Arc::new(KeyComputeClient::new("http://localhost:3000"));
let ollama_client = Arc::new(OllamaClient::new("http://localhost:11434"));
let session = SessionData {
node_id: uuid::Uuid::new_v4(),
session_id: uuid::Uuid::new_v4(),
session_token: "test-token".to_string(),
capabilities: NodeCapabilities {
runtime: "ollama".to_string(),
models: vec![],
},
poll_timeout_secs: 30,
};
TaskExecutor::new(client, ollama_client, session)
}
#[test]
fn test_chat_request_to_ollama_basic() {
use crate::protocol::types::{ChatCompletionRequest, Message};
let executor = create_test_executor();
let chat_req = ChatCompletionRequest::new(
"deepseek-chat",
vec![
Message::system("You are a helpful assistant"),
Message::user("Hello!"),
],
);
let ollama_req = executor.chat_request_to_ollama(&chat_req, "deepseek-chat");
assert_eq!(ollama_req.model, "deepseek-chat");
assert!(!ollama_req.stream);
assert_eq!(ollama_req.messages.len(), 2);
assert_eq!(ollama_req.messages[0].role, "system");
assert_eq!(
ollama_req.messages[0].content,
"You are a helpful assistant"
);
assert_eq!(ollama_req.messages[1].role, "user");
assert_eq!(ollama_req.messages[1].content, "Hello!");
}
#[test]
fn test_chat_request_to_ollama_multiple_messages() {
use crate::protocol::types::{ChatCompletionRequest, Message};
let executor = create_test_executor();
let chat_req = ChatCompletionRequest {
model: "llama3".to_string(),
messages: vec![
Message::system("System prompt"),
Message::user("Question 1"),
Message::assistant("Answer 1"),
Message::user("Question 2"),
],
stream: Some(false),
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stop: None,
};
let ollama_req = executor.chat_request_to_ollama(&chat_req, "llama3");
assert_eq!(ollama_req.model, "llama3");
assert_eq!(ollama_req.messages.len(), 4);
assert_eq!(ollama_req.messages[0].role, "system");
assert_eq!(ollama_req.messages[1].role, "user");
assert_eq!(ollama_req.messages[2].role, "assistant");
assert_eq!(ollama_req.messages[3].role, "user");
}
#[test]
fn test_ollama_response_to_chat_basic() {
use crate::protocol::ollama::OllamaChatResponse;
use crate::protocol::ollama::OllamaMessage;
let executor = create_test_executor();
let ollama_resp = OllamaChatResponse {
model: "deepseek-chat".to_string(),
created_at: "2024-01-01T00:00:00Z".to_string(),
message: OllamaMessage {
role: "assistant".to_string(),
content: "Hello! How can I help you?".to_string(),
},
done: true,
total_duration: 1000000000,
load_duration: 500000000,
prompt_eval_count: 10,
eval_count: 20,
};
let chat_resp = executor.ollama_response_to_chat(&ollama_resp, "deepseek-chat");
assert_eq!(chat_resp.object, "chat.completion");
assert_eq!(chat_resp.model, "deepseek-chat");
assert_eq!(chat_resp.choices.len(), 1);
assert_eq!(chat_resp.choices[0].index, 0);
assert_eq!(chat_resp.choices[0].message.role, "assistant");
assert_eq!(
chat_resp.choices[0].message.content,
"Hello! How can I help you?"
);
assert_eq!(chat_resp.choices[0].finish_reason, Some("stop".to_string()));
assert_eq!(chat_resp.usage.prompt_tokens, 10);
assert_eq!(chat_resp.usage.completion_tokens, 20);
assert_eq!(chat_resp.usage.total_tokens, 30);
}
#[test]
fn test_ollama_response_to_chat_empty_counts() {
use crate::protocol::ollama::OllamaChatResponse;
use crate::protocol::ollama::OllamaMessage;
let executor = create_test_executor();
let ollama_resp = OllamaChatResponse {
model: "llama3".to_string(),
created_at: "2024-01-01T00:00:00Z".to_string(),
message: OllamaMessage {
role: "assistant".to_string(),
content: "Response".to_string(),
},
done: true,
total_duration: 0,
load_duration: 0,
prompt_eval_count: 0,
eval_count: 0,
};
let chat_resp = executor.ollama_response_to_chat(&ollama_resp, "llama3");
assert_eq!(chat_resp.usage.prompt_tokens, 0);
assert_eq!(chat_resp.usage.completion_tokens, 0);
assert_eq!(chat_resp.usage.total_tokens, 0);
}
#[test]
fn test_chat_request_to_ollama_preserves_content() {
use crate::protocol::types::{ChatCompletionRequest, Message};
let executor = create_test_executor();
let complex_content = r#"{
"type": "code",
"language": "rust",
"code": "fn main() { println!(\"Hello\"); }"
}"#;
let chat_req =
ChatCompletionRequest::new("deepseek-coder", vec![Message::user(complex_content)]);
let ollama_req = executor.chat_request_to_ollama(&chat_req, "deepseek-coder");
assert_eq!(ollama_req.messages.len(), 1);
assert_eq!(ollama_req.messages[0].content, complex_content);
}
#[test]
fn test_ollama_response_to_chat_generates_unique_id() {
use crate::protocol::ollama::OllamaChatResponse;
use crate::protocol::ollama::OllamaMessage;
let executor = create_test_executor();
let ollama_resp = OllamaChatResponse {
model: "test".to_string(),
created_at: "2024-01-01T00:00:00Z".to_string(),
message: OllamaMessage {
role: "assistant".to_string(),
content: "test".to_string(),
},
done: true,
total_duration: 0,
load_duration: 0,
prompt_eval_count: 0,
eval_count: 0,
};
let chat_resp1 = executor.ollama_response_to_chat(&ollama_resp, "test");
let chat_resp2 = executor.ollama_response_to_chat(&ollama_resp, "test");
assert_ne!(chat_resp1.id, chat_resp2.id);
assert!(chat_resp1.id.starts_with("chatcmpl-"));
assert!(chat_resp2.id.starts_with("chatcmpl-"));
}
#[test]
fn test_execute_result_construction_success() {
use crate::protocol::types::NodeTaskResult;
let ollama_resp = crate::protocol::ollama::OllamaChatResponse {
model: "deepseek-chat".to_string(),
created_at: "2024-01-01T00:00:00Z".to_string(),
message: crate::protocol::ollama::OllamaMessage {
role: "assistant".to_string(),
content: "Hello!".to_string(),
},
done: true,
total_duration: 1000000000,
load_duration: 500000000,
prompt_eval_count: 10,
eval_count: 20,
};
let executor = create_test_executor();
let chat_resp = executor.ollama_response_to_chat(&ollama_resp, "deepseek-chat");
let result = NodeTaskResult::Succeeded {
response: chat_resp,
};
match result {
NodeTaskResult::Succeeded { response } => {
assert_eq!(response.model, "deepseek-chat");
assert_eq!(response.choices[0].message.content, "Hello!");
assert_eq!(response.usage.total_tokens, 30);
}
NodeTaskResult::Failed { .. } => panic!("Expected Succeeded variant"),
}
}
#[test]
fn test_execute_result_construction_failure() {
use crate::protocol::types::NodeTaskResult;
let error_msg = "Ollama API error: model not found";
let result = NodeTaskResult::Failed {
code: "ollama_error".to_string(),
message: error_msg.to_string(),
};
match result {
NodeTaskResult::Failed { code, message } => {
assert_eq!(code, "ollama_error");
assert_eq!(message, error_msg);
}
NodeTaskResult::Succeeded { .. } => panic!("Expected Failed variant"),
}
}
#[test]
fn test_task_deadline_and_grace_period() {
use chrono::{Duration, Utc};
let now = Utc::now();
let deadline = now + Duration::seconds(60); let grace_until = deadline + Duration::seconds(30);
assert!(deadline > now);
assert!(grace_until > deadline);
assert_eq!((grace_until - deadline).num_seconds(), 30);
assert!(now < deadline);
let after_deadline = deadline + Duration::seconds(10);
assert!(after_deadline > deadline);
assert!(after_deadline < grace_until);
let after_grace = grace_until + Duration::seconds(1);
assert!(after_grace > grace_until);
}
}