use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use anyhow::anyhow;
use anyhow::Result;
use parking_lot::RwLock;
use serde_json::Value;
use tokio::runtime::Builder;
use tokio::runtime::Runtime;
use tokio::task;
use tokio::time;
use tokio_util::sync::CancellationToken;
use crate::backend_logger::BackendLogger;
use crate::edge;
use crate::edge::InitializationData;
use crate::expression::Expression;
use crate::logger::Logger;
use crate::node_props::NodeProps;
use crate::node_props::NodePropsType;
use crate::query::Query;
use crate::types::InitQuery;
/// Shared SDK state: owns the background Tokio runtime, the handles of the
/// listener tasks running on it, and the latest initialization data fetched
/// from the edge service.
pub struct Context {
    // Background runtime driving the listener tasks; `None` once `close()`
    // (or `Drop`) has shut it down — used as the "already closed" marker.
    #[allow(dead_code)]
    runtime: Option<Runtime>,
    // Cancelled to ask every listener loop to exit.
    cancel_token: CancellationToken,
    // Periodic commit-hash polling task (see `listen_for_updates`).
    update_listener: Option<task::JoinHandle<()>>,
    // One-shot task fetching init data from the edge (see `init_from_hypertune_edge`).
    init_listener: Option<task::JoinHandle<()>>,
    // Periodic log-flushing task (see `periodically_flush_logs`).
    flush_logs_listener: Option<task::JoinHandle<()>>,
    // Auth token sent with edge requests and used to construct the backend logger.
    token: String,
    // Query used when requesting initialization data from the edge.
    init_query: InitQuery,
    // Query this context reduces expressions against.
    query: Query,
    // JSON variable values forwarded with every edge request.
    variable_values: Value,
    // Optional branch to pin edge requests to.
    branch_name: Option<String>,
    edge_base_url: String,
    remote_logging_base_url: String,
    language: edge::Language,
    // Local (in-process) logger.
    logger: Logger,
    // Buffers logs for shipment to `remote_logging_base_url`.
    pub backend_logger: BackendLogger,
    // Latest commit data; `None` until `init` succeeds at least once.
    pub init_data: Option<InitializationData>,
}
impl Drop for Context {
fn drop(&mut self) {
let runtime = self.runtime.take();
if runtime.is_none() {
return;
}
let runtime = runtime.unwrap();
self.cancel_token.cancel();
if let Some(update_listener) = self.update_listener.take() {
_ = runtime.block_on(update_listener);
}
if let Some(init_listener) = self.init_listener.take() {
_ = runtime.block_on(init_listener);
}
if let Some(flush_logs_listener) = self.flush_logs_listener.take() {
_ = runtime.block_on(flush_logs_listener);
}
runtime.block_on(async {
_ = BackendLogger::create_logs(
self.remote_logging_base_url.clone(),
self.backend_logger.collect_logs(),
)
.await
});
runtime.shutdown_timeout(time::Duration::from_secs(1));
}
}
impl Context {
/// Builds a `Context`, spawns its background tasks on a dedicated runtime,
/// and returns the root `NodeProps` that wraps it.
///
/// `fallback_init_data` (if provided) seeds the context synchronously so
/// reads work before the first fetch from the edge completes. Passing `0`
/// for either interval disables the corresponding periodic task.
///
/// # Errors
/// Fails if the runtime cannot be built or if reducing the fallback init
/// data fails.
pub fn initialize(
    variable_values: Value,
    token: String,
    init_query: InitQuery,
    query: Query,
    branch_name: Option<String>,
    init_data_refresh_interval_ms: u64,
    logs_flush_interval_ms: u64,
    edge_base_url: String,
    remote_logging_base_url: String,
    language: edge::Language,
    fallback_init_data: Option<InitializationData>,
) -> Result<NodeProps> {
    // Dedicated single-worker runtime that owns all background tasks.
    let runtime = Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()?;
    // Enter the runtime so the `task::spawn` calls below land on it.
    let _rt_guard = runtime.enter();
    let mut inner_context = Self {
        cancel_token: CancellationToken::new(),
        logger: Logger::new(),
        runtime: Some(runtime),
        update_listener: None,
        init_listener: None,
        flush_logs_listener: None,
        language,
        // Clone once for the logger, then move `token` into the struct
        // (the original cloned twice via `to_owned` + `clone`).
        backend_logger: BackendLogger::new(token.clone()),
        token,
        query,
        init_query,
        variable_values,
        branch_name,
        edge_base_url,
        remote_logging_base_url,
        init_data: None,
    };
    // Seed state from the caller-provided fallback, if any.
    if let Some(init_data) = fallback_init_data {
        inner_context.init(init_data)?;
    }
    // Clone the cancel token before wrapping the context in the lock, so
    // the spawns below don't need extra `write()` round-trips to reach it.
    let cancel_token = inner_context.cancel_token.clone();
    let context = Arc::new(RwLock::new(inner_context));
    // Periodic refresh of init data (disabled when the interval is zero).
    if init_data_refresh_interval_ms > 0 {
        let update_listener = task::spawn(Self::listen_for_updates(
            context.clone(),
            cancel_token.clone(),
            time::interval(Duration::from_millis(init_data_refresh_interval_ms)),
        ));
        context.write().update_listener = Some(update_listener);
    }
    // One-shot fetch of the latest init data from the edge.
    let init_listener = task::spawn(Self::init_from_hypertune_edge(context.clone()));
    context.write().init_listener = Some(init_listener);
    // Periodic log flushing (disabled when the interval is zero).
    if logs_flush_interval_ms > 0 {
        let flush_logs_listener = task::spawn(Self::periodically_flush_logs(
            context.clone(),
            cancel_token,
            time::interval(Duration::from_millis(logs_flush_interval_ms)),
        ));
        context.write().flush_logs_listener = Some(flush_logs_listener);
    }
    Ok(NodeProps {
        r#type: NodePropsType::None,
        context,
        expression: None,
        commit_hash: None,
        parent: None,
        step: None,
    })
}
/// Blocks the calling (sync) thread until the one-shot init task finishes.
///
/// The original built a brand-new multi-thread `Runtime` on every call just
/// to join one task; we instead reuse the context's own runtime via a cloned
/// `Handle`, and only fall back to a throwaway runtime when the context has
/// already been closed (runtime taken).
///
/// # Panics
/// Propagates a panic from the init task (as before), and panics if called
/// from within an async context (a `block_on` restriction).
pub fn wait_for_initialization(context: Arc<RwLock<Context>>) {
    // Take the listener and a runtime handle under one short-lived lock so
    // the lock is never held while blocking on the task.
    let (init_listener, handle) = {
        let mut context_guard = context.write();
        let handle = context_guard
            .runtime
            .as_ref()
            .map(|runtime| runtime.handle().clone());
        (context_guard.init_listener.take(), handle)
    };
    if let Some(listener) = init_listener {
        match handle {
            // Normal path: join on the context's own runtime.
            Some(handle) => handle.block_on(listener).unwrap(),
            // Context already closed: fall back to the original behavior of
            // creating a temporary runtime to drive the join.
            None => Runtime::new().unwrap().block_on(listener).unwrap(),
        }
    }
}
/// Fetches the latest initialization data from the edge and installs it on
/// the context. Errors (network or reduction) are reported via `log_error`.
async fn init_from_hypertune_edge(context: Arc<RwLock<Context>>) {
    // Snapshot request parameters under a short-lived read lock so no lock
    // is held across the network await below.
    let (token, init_query, variables, language, edge_base_url, branch_name) = {
        let context_guard = context.read();
        (
            context_guard.token.clone(),
            context_guard.init_query.clone(),
            context_guard.variable_values.clone(),
            context_guard.language,
            context_guard.edge_base_url.clone(),
            context_guard.branch_name.clone(),
        )
    };
    let init_data_result = edge::init_request(
        &token,
        branch_name,
        &init_query,
        &variables,
        language,
        edge_base_url.as_str(),
    )
    .await;
    match init_data_result {
        Ok(init_data) => {
            // BUG FIX: the original bound the write guard to a local
            // (`mut_context`) that was still alive when the error path
            // called `Context::log_error`, which re-acquires the write
            // lock — parking_lot locks are not reentrant, so that
            // deadlocked. The guard here is a temporary that is released
            // at the end of the statement, before any logging happens.
            let init_result = context.write().init(init_data);
            if let Err(err) = init_result {
                Context::log_error(
                    context.clone(),
                    format!(
                        "Failed to initialize from latest commit data: {}",
                        err.to_string()
                    ),
                )
            }
        }
        Err(err) => Context::log_error(
            context.clone(),
            format!("Failed to fetch latest commit data: {}", err.to_string()),
        ),
    }
}
/// Background loop: polls the edge for the latest commit hash on every
/// `interval` tick and re-initializes when the hash changes.
///
/// Runs until `cancel_token` is cancelled. Request parameters are
/// snapshotted once at startup, so later changes to the context's
/// token/query/variables are not picked up by this loop.
async fn listen_for_updates(
    context: Arc<RwLock<Context>>,
    cancel_token: CancellationToken,
    mut interval: time::Interval,
) {
    // Copy everything the request needs out of the lock up front so the
    // loop below never holds the lock across an await point.
    let (token, init_query, variables, language, edge_base_url, branch_name) = {
        let context = context.read();
        (
            context.token.clone(),
            context.init_query.clone(),
            context.variable_values.clone(),
            context.language,
            context.edge_base_url.clone(),
            context.branch_name.clone(),
        )
    };
    loop {
        // Wait for the next tick, or exit as soon as shutdown is requested.
        tokio::select! {
            _ = cancel_token.cancelled() => {
                break;
            }
            _ = interval.tick() => {}
        }
        let response = edge::hash_request(
            &token,
            branch_name.clone(),
            &init_query,
            &variables,
            language,
            edge_base_url.as_str(),
        )
        .await;
        match response {
            Ok(response) => {
                // Current local state; `(0, "")` when no init data exists yet.
                let (current_commit_id, current_commit_hash) = {
                    context
                        .read()
                        .init_data
                        .as_ref()
                        .map(|init_data| (init_data.commit_id, init_data.hash.clone()))
                        .unwrap_or((0, "".to_string()))
                };
                // Never move backwards to an older commit.
                if current_commit_id > response.commit_id {
                    continue;
                }
                // Hash changed: fetch and apply the full init data.
                if current_commit_hash != response.hash {
                    Self::init_from_hypertune_edge(context.clone()).await
                }
            }
            Err(err) => Context::log_error(
                context.clone(),
                format!("Failed to fetch latest commit hash: {}", err.to_string()),
            ),
        }
    }
}
/// Background loop: flushes buffered logs on every `interval` tick until
/// `cancel_token` is cancelled.
async fn periodically_flush_logs(
    context: Arc<RwLock<Context>>,
    cancel_token: CancellationToken,
    mut interval: time::Interval,
) {
    loop {
        tokio::select! {
            // Shutdown requested: stop the loop immediately.
            _ = cancel_token.cancelled() => return,
            // Tick elapsed: push whatever logs have accumulated.
            _ = interval.tick() => Self::flush_logs(context.clone()).await,
        }
    }
}
/// Drains all buffered backend logs and ships them to the remote logging
/// service. Failures are recorded on the local logger.
pub async fn flush_logs(context: Arc<RwLock<Context>>) {
    // Grab the payload and target URL under a short-lived write lock so the
    // lock is never held across the network await below.
    let (payload, url) = {
        let mut guard = context.write();
        let payload = guard.backend_logger.collect_logs();
        let url = guard.remote_logging_base_url.clone();
        (payload, url)
    };
    if let Err(err) = BackendLogger::create_logs(url, payload).await {
        context
            .write()
            .logger
            .log_error(format!("Failed to flush logs: {}", err));
    }
}
/// Records `msg` on both the backend (remote) logger and the local logger.
pub fn log_error(context: Arc<RwLock<Context>>, msg: String) {
    let mut guard = context.write();
    guard.backend_logger.log_error(msg.clone());
    guard.logger.log_error(msg);
}
/// Reduces `expression` against the current init data's splits and commit
/// config.
///
/// # Errors
/// Fails if the context has no init data yet, or if reduction itself fails.
pub fn reduce(&self, query: Option<&Query>, expression: Expression) -> Result<Expression> {
    // Reduction needs commit data; bail out if we never initialized.
    let init_data = match self.init_data.as_ref() {
        Some(data) => data,
        None => return Err(anyhow!("No init data so cannot reduce expression")),
    };
    expression.reduce(
        query,
        None,
        HashMap::new(),
        false,
        init_data.splits.clone(),
        HashMap::new(),
        init_data.commit_config.clone(),
    )
}
/// Installs `init_data` on this context after pre-reducing its expression
/// tree against the context's query, and points the backend logger at the
/// new commit id.
///
/// # Errors
/// Propagates any failure from the reduction step; on error the context's
/// existing `init_data` is left untouched.
fn init(&mut self, mut init_data: InitializationData) -> Result<()> {
    // Reduce once up front so later evaluations work on the smallest
    // equivalent expression tree.
    init_data.reduced_expression = init_data.reduced_expression.reduce(
        Some(&self.query),
        None,
        HashMap::new(),
        false,
        init_data.splits.clone(),
        HashMap::new(),
        init_data.commit_config.clone(),
    )?;
    // Tag subsequent backend logs with the commit they were produced from.
    self.backend_logger
        .set_commit_id(init_data.commit_id.to_string());
    self.init_data = Some(init_data);
    Ok(())
}
/// Explicit shutdown: stops all background tasks, flushes remaining logs,
/// and tears down the runtime.
///
/// Idempotent — a second call (or a later `Drop`) sees `runtime == None`
/// and returns immediately.
pub fn close(context: Arc<RwLock<Context>>) {
    // Move the runtime and listener handles out under a single write lock
    // so the lock is NOT held while blocking on the tasks below.
    let (runtime, update_listener, init_listener, flush_logs_listener) = {
        let mut context = context.write();
        let runtime = context.runtime.take();
        if runtime.is_some() {
            // Signal every listener loop to exit.
            context.cancel_token.cancel();
        }
        (
            runtime,
            context.update_listener.take(),
            context.init_listener.take(),
            context.flush_logs_listener.take(),
        )
    };
    // Already closed (or never fully initialized): nothing to do.
    if runtime.is_none() {
        return;
    }
    let runtime = runtime.unwrap();
    // Join each task; panics inside the tasks are ignored during shutdown.
    if let Some(update_listener) = update_listener {
        _ = runtime.block_on(update_listener);
    }
    if let Some(init_listener) = init_listener {
        _ = runtime.block_on(init_listener);
    }
    if let Some(flush_logs_listener) = flush_logs_listener {
        _ = runtime.block_on(flush_logs_listener);
    }
    // NOTE(review): the write lock is held across this `block_on`. That
    // appears safe only because every task that could contend for the lock
    // was just joined above — confirm before adding new tasks.
    let mut context = context.write();
    // Best-effort final flush of any logs still buffered locally.
    runtime.block_on(async {
        _ = BackendLogger::create_logs(
            context.remote_logging_base_url.clone(),
            context.backend_logger.collect_logs(),
        )
        .await
    });
    runtime.shutdown_timeout(time::Duration::from_secs(1));
}
#[cfg(test)]
/// Builds a minimal `Context` for unit tests: empty token/query/urls, a
/// no-op reduced expression, and already-completed tasks standing in for
/// the real listeners.
pub fn test() -> Self {
    use crate::expression::NoOpExpression;
    use crate::split::CommitConfig;
    use crate::split::SplitMap;
    use crate::types::StoredQuery;
    // Same runtime shape as `initialize`, so `Drop` behaves identically.
    let runtime = Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .unwrap();
    let _rt_guard = runtime.enter();
    // Trivially-finished tasks so shutdown paths have something to join.
    let update_listener = Some(task::spawn(async {}));
    let init_listener = Some(task::spawn(async {}));
    let flush_logs_listener = Some(task::spawn(async {}));
    Self {
        logger: Logger::new(),
        runtime: Some(runtime),
        update_listener,
        init_listener,
        flush_logs_listener,
        cancel_token: CancellationToken::new(),
        token: "".to_owned(),
        query: HashMap::new(),
        init_query: InitQuery::Stored(StoredQuery { id: "".to_string() }),
        variable_values: serde_json::from_str("{}").unwrap(),
        language: edge::Language::Rust,
        branch_name: None,
        edge_base_url: "".to_string(),
        remote_logging_base_url: "".to_string(),
        backend_logger: BackendLogger::new("".to_string()),
        // Pre-populated init data so tests can call `reduce` immediately.
        init_data: Some(InitializationData {
            commit_id: 0,
            reduced_expression: Expression::NoOp(NoOpExpression::transient()),
            hash: "".to_string(),
            commit_config: CommitConfig::new(),
            splits: SplitMap::new(),
        }),
    }
}
}