use std::collections::HashMap;
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status, transport::Server};
use prompt_graph_core::proto::{ChangeValueWithCounter, Empty, ExecutionStatus, File, FileAddressedChangeValueWithCounter, FilteredPollNodeWillExecuteEventsRequest, InputProposal, ListBranchesRes, ListRegisteredGraphsResponse, NodeWillExecuteOnBranch, ParquetFile, QueryAtFrame, QueryAtFrameResponse, RequestAckNodeWillExecuteEvent, RequestAtFrame, RequestFileMerge, RequestInputProposalResponse, RequestListBranches, RequestNewBranch, RequestOnlyId, RespondPollNodeWillExecuteEvents, UpsertPromptLibraryRecord};
use prompt_graph_core::proto::execution_runtime_server::{ExecutionRuntime, ExecutionRuntimeServer};
use log::debug;
use prompt_graph_core::execution_router::evaluate_changes_against_node;
use prompt_graph_core::build_runtime_graph::graph_parse::query_path_from_query_string;
use crate::db_operations::get_change_counter_for_branch;
use crate::db_operations::branches::{create_branch, create_root_branch, list_branches};
use crate::db_operations::changes::scan_all_resolved_changes;
use crate::db_operations::input_proposals_and_responses::insert_input_response;
use crate::db_operations::playback::pause_execution_at_frame;
use crate::db_operations::playback::play_execution_at_frame;
use crate::db_operations::changes::scan_all_pending_changes;
use crate::db_operations::custom_node_execution::insert_custom_node_execution;
use crate::db_operations::graph_mutations::{insert_pending_graph_mutation, scan_all_file_mutations_on_branch};
use crate::db_operations::input_proposals_and_responses::scan_all_input_proposals;
use crate::db_operations::executing_nodes::{move_will_execute_event_to_complete, move_will_execute_event_to_in_progress, scan_all_custom_node_will_execute_events, scan_all_will_execute_events, scan_all_will_execute_pending_events};
use crate::db_operations::files::{insert_executor_file_existence_by_id, scan_all_executor_files};
use crate::db_operations::prompt_library::insert_prompt_library_mutation;
use crate::executor::{Executor, InternalStateHandler};
/// gRPC implementation of the `ExecutionRuntime` service, backed by a sled database.
///
/// Every graph id gets its own sled tree (see [`MyExecutionRuntime::get_tree`]); the tree
/// named `"root"` is additionally used by the service handlers as a registry of graph ids.
#[derive(Debug)]
pub struct MyExecutionRuntime {
    /// Shared handle to the underlying sled database.
    db: Arc<sled::Db>,
    /// Graph ids for which `play` has already spawned an executor task in this process.
    executor_started: Arc<DashMap<String, bool>>
}

impl MyExecutionRuntime {
    /// Database location used when no explicit path is supplied.
    const DEFAULT_DB_PATH: &'static str = "/tmp/prompt-graph";

    /// Open (or create) the sled database that backs the runtime.
    ///
    /// * `None` — use [`Self::DEFAULT_DB_PATH`].
    /// * a path containing `":memory:"` — use a temporary sled database.
    /// * any other path — open the database stored at that path.
    ///
    /// # Panics
    /// Panics if sled cannot open the database.
    fn new(file_path: Option<String>) -> Self {
        let base = sled::Config::default();
        let db_config = match file_path {
            Some(path) if path.contains(":memory:") => base.temporary(true),
            Some(path) => base.path(path),
            None => base.path(Self::DEFAULT_DB_PATH),
        };
        let db = db_config
            .open()
            .expect("failed to open the sled database for the execution runtime");
        MyExecutionRuntime {
            db: Arc::new(db),
            executor_started: Arc::new(DashMap::new()),
        }
    }

    /// Open the sled tree holding all state for graph `id` (sled creates it on first use).
    ///
    /// # Panics
    /// Panics if sled cannot open the tree.
    fn get_tree(&self, id: &str) -> sled::Tree {
        // No need to clone the `Arc` just to call a `&self` method on the database.
        self.db
            .open_tree(id)
            .unwrap_or_else(|e| panic!("failed to open sled tree {:?}: {:?}", id, e))
    }
}
#[tonic::async_trait]
impl ExecutionRuntime for MyExecutionRuntime {
    /// Evaluate a query against graph `id` as of `frame` on `branch`.
    ///
    /// Returns `InvalidArgument` when the request carries no query string or the query cannot
    /// be parsed, and an empty `values` list when the evaluation yields nothing.
    #[tracing::instrument]
    async fn run_query(&self, request: Request<QueryAtFrame>) -> Result<Response<QueryAtFrameResponse>, Status> {
        debug!("Received run_query request: {:?}", &request);
        let req = request.get_ref();
        // Both the query message and its string are optional on the wire; a missing one is a
        // client error, not a reason to bring down the handler.
        let query_string = req
            .query
            .as_ref()
            .and_then(|query| query.query.as_ref())
            .ok_or_else(|| Status::invalid_argument("run_query requires a query string"))?;
        let paths = query_path_from_query_string(query_string)
            .map_err(|e| Status::invalid_argument(format!("invalid query: {:?}", e)))?;
        let tree = self.get_tree(&req.id);
        let state = InternalStateHandler {
            tree: &tree,
            branch: req.branch,
            counter: req.frame,
        };
        let values = evaluate_changes_against_node(&state, &paths).unwrap_or_default();
        Ok(Response::new(QueryAtFrameResponse { values }))
    }

    /// Queue `file` as a pending graph mutation on `branch` of graph `id`.
    ///
    /// The returned status carries the merged file's id and the branch's current change
    /// counter. Returns `InvalidArgument` when the request carries no file.
    #[tracing::instrument]
    async fn merge(&self, request: Request<RequestFileMerge>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received merge request: {:?}", request);
        let req = request.into_inner();
        let file = req
            .file
            .ok_or_else(|| Status::invalid_argument("merge requires a file"))?;
        let id = file.id.clone();
        let tree = self.get_tree(&req.id);
        insert_pending_graph_mutation(&tree, req.branch, file);
        let monotonic_counter = get_change_counter_for_branch(&tree, req.branch);
        Ok(Response::new(ExecutionStatus { id, monotonic_counter, branch: req.branch }))
    }

    /// Reconstruct the current file of graph `id` on `branch` from its stored file mutations.
    ///
    /// For each node name, the definition from the mutation whose key has the largest second
    /// component (presumably the change counter) wins; on ties the first one scanned is kept.
    /// Nodes are returned sorted by name so the response is deterministic.
    ///
    /// Returns `Internal` if a stored node has no `core` (and therefore no name).
    #[tracing::instrument]
    async fn current_file_state(&self, request: Request<RequestOnlyId>) -> Result<Response<File>, Status> {
        debug!("Received current_file_state request: {:?}", request);
        let req = request.get_ref();
        let tree = self.get_tree(&req.id);
        // node name -> (counter of the mutation that last defined it, node definition)
        let mut latest: HashMap<String, (u64, _)> = HashMap::new();
        for (_is_resolved, key, mutation) in scan_all_file_mutations_on_branch(&tree, req.branch) {
            let counter = key.1;
            for node in mutation.nodes {
                let name = match node.core.as_ref() {
                    Some(core) => core.name.clone(),
                    None => {
                        return Err(Status::internal(format!(
                            "graph {} has a stored node without a core definition",
                            req.id
                        )))
                    }
                };
                let is_newer = latest.get(&name).map_or(true, |(seen, _)| *seen < counter);
                if is_newer {
                    latest.insert(name, (counter, node));
                }
            }
        }
        let mut entries: Vec<_> = latest.into_iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        Ok(Response::new(File {
            id: req.id.clone(),
            nodes: entries.into_iter().map(|(_, (_, node))| node).collect(),
        }))
    }

    /// Parquet export of a graph's history is not implemented yet; always returns
    /// `Unimplemented` instead of panicking the connection.
    #[tracing::instrument]
    async fn get_parquet_history(&self, request: Request<RequestOnlyId>) -> Result<Response<ParquetFile>, Status> {
        debug!("Received get_parquet_history request: {:?}", request);
        Err(Status::unimplemented("get_parquet_history is not implemented yet"))
    }

    /// Resume graph `id` at `frame`.
    ///
    /// The first time a graph is played in this process it is also registered in the `root`
    /// tree, its root branch is created and an executor task is spawned. Later calls only
    /// update the playback state and report a `monotonic_counter` of 0.
    #[tracing::instrument]
    async fn play(&self, request: Request<RequestAtFrame>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received play request: {:?}", request);
        let req = request.get_ref();
        let id = &req.id;
        let branch = req.branch;
        let tree = self.get_tree(id);
        play_execution_at_frame(&tree, req.frame);
        // `insert` returns the previous value atomically, so exactly one caller wins the right
        // to spawn the executor even when several `play` requests race for the same id.
        if self.executor_started.insert(id.clone(), true).is_some() {
            return Ok(Response::new(ExecutionStatus { id: id.clone(), monotonic_counter: 0, branch }));
        }
        let root_tree = self.get_tree("root");
        insert_executor_file_existence_by_id(&root_tree, id.clone());
        create_root_branch(&tree);
        let executor_tree = tree.clone();
        tokio::spawn(async move {
            let mut executor = Executor::new(executor_tree);
            executor.run().await;
        });
        let monotonic_counter = get_change_counter_for_branch(&tree, branch);
        Ok(Response::new(ExecutionStatus { id: id.clone(), monotonic_counter, branch }))
    }

    /// Pause graph `id` at `frame` and report the current change counter of `branch`.
    #[tracing::instrument]
    async fn pause(&self, request: Request<RequestAtFrame>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received pause request: {:?}", request);
        let req = request.get_ref();
        let tree = self.get_tree(&req.id);
        pause_execution_at_frame(&tree, req.frame);
        let monotonic_counter = get_change_counter_for_branch(&tree, req.branch);
        Ok(Response::new(ExecutionStatus { id: req.id.clone(), monotonic_counter, branch: req.branch }))
    }

    /// Create a new branch of graph `id` from `source_branch_id` and report its change counter.
    #[tracing::instrument]
    async fn branch(&self, request: Request<RequestNewBranch>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received branch request: {:?}", request);
        let req = request.into_inner();
        let tree = self.get_tree(&req.id);
        // NOTE(review): the third argument is hard-coded to 0 — confirm it should not come
        // from the request.
        let new_branch_id = create_branch(&tree, req.source_branch_id, 0);
        let monotonic_counter = get_change_counter_for_branch(&tree, new_branch_id);
        Ok(Response::new(ExecutionStatus { id: req.id, monotonic_counter, branch: new_branch_id }))
    }

    /// List every branch recorded for graph `id`.
    #[tracing::instrument]
    async fn list_branches(&self, request: Request<RequestListBranches>) -> Result<Response<ListBranchesRes>, Status> {
        debug!("Received list_branches request: {:?}", request);
        let id = request.into_inner().id;
        let branches = list_branches(&self.get_tree(&id)).collect();
        Ok(Response::new(ListBranchesRes { id, branches }))
    }

    /// List the ids of all graphs registered in the `root` tree.
    #[tracing::instrument]
    async fn list_registered_graphs(&self, request: Request<Empty>) -> Result<Response<ListRegisteredGraphsResponse>, Status> {
        debug!("Received list_registered_graphs request: {:?}", request);
        let root_tree = self.get_tree("root");
        Ok(Response::new(ListRegisteredGraphsResponse {
            ids: scan_all_executor_files(&root_tree).collect(),
        }))
    }

    type ListInputProposalsStream = ReceiverStream<Result<InputProposal, Status>>;

    /// Stream every input proposal stored for graph `id`.
    ///
    /// The producer task stops quietly if the client disconnects mid-stream.
    #[tracing::instrument]
    async fn list_input_proposals(&self, request: Request<RequestOnlyId>) -> Result<Response<Self::ListInputProposalsStream>, Status> {
        debug!("Received list_input_proposals request: {:?}", request);
        let (tx, rx) = mpsc::channel(4);
        let tree = self.get_tree(&request.get_ref().id);
        tokio::spawn(async move {
            for proposal in scan_all_input_proposals(&tree) {
                // A send error means the receiver (client stream) was dropped.
                if tx.send(Ok(proposal)).await.is_err() {
                    return;
                }
            }
        });
        Ok(Response::new(ReceiverStream::new(rx)))
    }

    /// Store a response to an input proposal for graph `id`.
    #[tracing::instrument]
    async fn respond_to_input_proposal(&self, request: Request<RequestInputProposalResponse>) -> Result<Response<Empty>, Status> {
        debug!("Received respond_to_input_proposal request: {:?}", request);
        let tree = self.get_tree(&request.get_ref().id);
        insert_input_response(&tree, request.into_inner());
        Ok(Response::new(Empty::default()))
    }

    type ListChangeEventsStream = ReceiverStream<Result<ChangeValueWithCounter, Status>>;

    /// Stream all pending change events of graph `id`, followed by all resolved ones.
    ///
    /// The producer task stops quietly if the client disconnects mid-stream.
    #[tracing::instrument]
    async fn list_change_events(&self, request: Request<RequestOnlyId>) -> Result<Response<Self::ListChangeEventsStream>, Status> {
        debug!("Received list_change_events request: {:?}", request);
        let (tx, rx) = mpsc::channel(4);
        let tree = self.get_tree(&request.get_ref().id);
        tokio::spawn(async move {
            for change in scan_all_pending_changes(&tree) {
                if tx.send(Ok(change)).await.is_err() {
                    return;
                }
            }
            for change in scan_all_resolved_changes(&tree) {
                if tx.send(Ok(change)).await.is_err() {
                    return;
                }
            }
        });
        Ok(Response::new(ReceiverStream::new(rx)))
    }

    type ListNodeWillExecuteEventsStream = ReceiverStream<Result<NodeWillExecuteOnBranch, Status>>;

    /// Stream every node-will-execute event stored for graph `id`.
    ///
    /// The producer task stops quietly if the client disconnects mid-stream.
    #[tracing::instrument]
    async fn list_node_will_execute_events(&self, request: Request<RequestOnlyId>) -> Result<Response<Self::ListNodeWillExecuteEventsStream>, Status> {
        debug!("Received list_node_will_execute_events request: {:?}", request);
        let (tx, rx) = mpsc::channel(4);
        let tree = self.get_tree(&request.get_ref().id);
        tokio::spawn(async move {
            for event in scan_all_will_execute_events(&tree) {
                if tx.send(Ok(event)).await.is_err() {
                    return;
                }
            }
        });
        Ok(Response::new(ReceiverStream::new(rx)))
    }

    /// Return the custom-node will-execute events currently stored for graph `id`.
    #[tracing::instrument]
    async fn poll_custom_node_will_execute_events(&self, request: Request<FilteredPollNodeWillExecuteEventsRequest>) -> Result<Response<RespondPollNodeWillExecuteEvents>, Status> {
        debug!("Received poll_custom_node_will_execute_events request: {:?}", request);
        let tree = self.get_tree(&request.get_ref().id);
        Ok(Response::new(RespondPollNodeWillExecuteEvents {
            node_will_execute_events: scan_all_custom_node_will_execute_events(&tree).collect(),
        }))
    }

    /// Acknowledge the will-execute event at (`branch`, `counter`) of graph `id` by moving it
    /// to the in-progress state via `move_will_execute_event_to_in_progress`.
    #[tracing::instrument]
    async fn ack_node_will_execute_event(&self, request: Request<RequestAckNodeWillExecuteEvent>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received ack_node_will_execute_event request: {:?}", request);
        let req = request.into_inner();
        let tree = self.get_tree(&req.id);
        move_will_execute_event_to_in_progress(&tree, true, req.branch, req.counter);
        Ok(Response::new(ExecutionStatus::default()))
    }

    /// Record the result of a worker-executed node.
    ///
    /// Moves the will-execute event at (`branch`, `counter`) to complete and stores `change`.
    /// Returns `InvalidArgument` when the request carries no change value.
    #[tracing::instrument]
    async fn push_worker_event(&self, request: Request<FileAddressedChangeValueWithCounter>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received push_worker_event request: {:?}", request);
        let req = request.into_inner();
        let change = req
            .change
            .ok_or_else(|| Status::invalid_argument("push_worker_event requires a change value"))?;
        let tree = self.get_tree(&req.id);
        let _node_will_exec = move_will_execute_event_to_complete(&tree, true, req.branch, req.counter);
        insert_custom_node_execution(&tree, change);
        Ok(Response::new(ExecutionStatus::default()))
    }

    /// Upsert a prompt-library record (template partial) into graph `id`.
    #[tracing::instrument]
    async fn push_template_partial(&self, request: Request<UpsertPromptLibraryRecord>) -> Result<Response<ExecutionStatus>, Status> {
        debug!("Received push_template_partial request: {:?}", request);
        let record = request.get_ref();
        let tree = self.get_tree(&record.id);
        insert_prompt_library_mutation(&tree, record);
        Ok(Response::new(ExecutionStatus::default()))
    }
}
/// Start the `ExecutionRuntime` gRPC server and block until it stops.
///
/// `url_server` may start with an `http://` or `https://` scheme and end with a `/`; both
/// are stripped. `localhost` is rewritten to `127.0.0.1` so the remainder parses as a socket
/// address (e.g. `http://localhost:9800` becomes `127.0.0.1:9800`). `file_path` is forwarded
/// to the database setup (`None` = default path, containing `":memory:"` = temporary DB).
///
/// `#[tokio::main]` makes this a synchronous function that builds its own runtime, so it
/// must not be called from inside an existing Tokio runtime.
///
/// # Errors
/// Returns an error if the address cannot be parsed or the server fails.
///
/// # Panics
/// Panics if the database cannot be opened.
#[tokio::main]
pub async fn run_server(url_server: String, file_path: Option<String>) -> Result<(), Box<dyn std::error::Error>> {
    let without_scheme = url_server
        .strip_prefix("http://")
        .or_else(|| url_server.strip_prefix("https://"))
        .unwrap_or(url_server.as_str());
    let host_port = without_scheme
        .trim_end_matches('/')
        .replace("localhost", "127.0.0.1");
    // Propagate a malformed address to the caller instead of panicking.
    let addr: std::net::SocketAddr = host_port.parse()?;
    let server = MyExecutionRuntime::new(file_path);
    println!("ExecutionRuntime listening on {}", addr);
    Server::builder()
        .add_service(ExecutionRuntimeServer::new(server))
        .serve(addr)
        .await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing a template partial succeeds, returns the default status and writes to the
    /// graph's tree.
    #[tokio::test]
    async fn test_pushing_a_partial_template() {
        let runtime = MyExecutionRuntime::new(Some(":memory:".to_string()));
        let response = runtime
            .push_template_partial(Request::new(UpsertPromptLibraryRecord {
                description: None,
                template: "Testing".to_string(),
                name: "named".to_string(),
                id: "test".to_string(),
            }))
            .await
            .expect("push_template_partial should succeed");
        assert_eq!(response.into_inner(), ExecutionStatus::default());

        let tree = runtime.get_tree("test");
        assert!(
            !tree.is_empty(),
            "the partial should have been written to the graph's tree"
        );
    }
}