use crate::activity::copy_panic_payload;
use crate::activity::ActivityId;
use crate::activity::FunctionLabel;
use crate::activity::SquashBufferFactory;
use crate::activity::TaskItem;
use crate::activity::TaskItemBuilder;
use crate::args::RemoteSend;
use crate::collective;
use crate::executor;
use crate::global_id::ActivityIdMethods;
use crate::global_id::FinishIdMethods;
use crate::logging;
use crate::logging::*;
use crate::meta_data;
use crate::network::context::CommunicationContext;
use crate::network::MessageHandler;
use crate::network::Rank;
use crate::place;
use crate::runtime::init_task_item_channels;
use crate::runtime::init_worker_task_queue;
use crate::runtime::message_recv_callback;
use crate::runtime::take_message_buffer_receiver;
use crate::runtime::take_worker_task_receiver;
use crate::runtime::ApgasContext;
use crate::runtime::ConcreteContext;
use crate::runtime::Distributor;
use crate::runtime::ExecutionHub;
use crate::runtime_meta;
use futures::future::BoxFuture;
use futures::future::FutureExt;
use futures::Future;
use std::thread;
/// Reports the outcome of activity `a_id` by sending task items.
///
/// * One item is always sent to the place of the activity's finish id
///   (`a_id.get_finish_id().get_place()`). Its return value is reduced to
///   `Ok(())`, or to `Err` holding a copy of the panic payload made with
///   `copy_panic_payload`.
/// * When `waited` is true, a second item is sent to
///   `a_id.get_spawned_place()`. It carries the full `result` and is marked
///   with `waited()`.
///
/// Both items list the sub-activities returned by `ctx.spawned()`.
pub fn send_activity_result<T: RemoteSend>(
    ctx: impl ApgasContext,
    a_id: ActivityId,
    fn_id: FunctionLabel,
    waited: bool,
    result: std::thread::Result<T>,
) {
    let finish_id = a_id.get_finish_id();
    // The finish side only receives success/failure (plus any panic payload), not `T`.
    let stripped_result: thread::Result<()> = match &result {
        Ok(_) => Ok(()),
        Err(e) => Err(copy_panic_payload(e)),
    };
    let mut builder = TaskItemBuilder::new(fn_id, finish_id.get_place(), a_id);
    let spawned_activities = ctx.spawned();
    // Clone the sub-activity list only when a second (waited) item needs its own copy.
    let (finish_subs, waiter_subs) = if waited {
        (spawned_activities.clone(), Some(spawned_activities))
    } else {
        (spawned_activities, None)
    };
    builder.ret(stripped_result);
    builder.sub_activities(finish_subs);
    let item = builder.build_box();
    ConcreteContext::send(item);

    if let Some(subs) = waiter_subs {
        let mut builder = TaskItemBuilder::new(fn_id, a_id.get_spawned_place(), a_id);
        builder.ret(result);
        builder.sub_activities(subs);
        builder.waited();
        let item = builder.build_box();
        ConcreteContext::send(item);
    }
}
/// Looks up the handler registered for `item`'s function id and calls it
/// with `item`, returning the handler's boxed future.
///
/// # Panics
/// Panics if the item's function id has no entry in the table returned by
/// `runtime_meta::get_func_table()`.
fn worker_dispatch(item: TaskItem) -> BoxFuture<'static, ()> {
    let fn_id = item.function_id();
    let handler = runtime_meta::get_func_table()
        .get(&fn_id)
        .expect("worker_dispatch: task function id is not registered in the function table")
        .fn_ptr;
    handler(item)
}
/// Takes the collective operator provided by `ctx`, boxes it, and registers it
/// through `collective::set_coll`.
pub fn init_collective_operator<T: MessageHandler>(ctx: &CommunicationContext<T>) {
    let operator = ctx.collective_operator();
    let boxed = Box::new(operator);
    collective::set_coll(boxed);
}
/// Runtime entry point: boots the runtime, runs the future produced by `main`
/// to completion, performs a collective barrier, then shuts everything down
/// and returns `main`'s output.
///
/// `main` is called with the command-line arguments obtained from the
/// communication context (`context.cmd_args()`). The future it returns is
/// spawned on a multi-threaded executor and awaited.
///
/// Order of operations, as written below:
/// 1. Set up the logger, show metadata, and initialize the function table and helpers.
/// 2. Create the communication context. Incoming messages are forwarded to
///    `message_recv_callback::<SquashBufferFactory>`.
/// 3. Invoke `main` to build its future.
/// 4. Initialize the worker task queue, task-item channels, `place` here/world
///    size, and the collective operator.
/// 5. Start a network thread (`context.init()` then `context.run()`) and an
///    execution-hub thread (`hub.run()`).
/// 6. Build the executor runtime and run the worker dispatch loop and `main`'s future.
/// 7. Wait for the network-init signal, run the collective barrier, stop the hub
///    via its trigger, drop the collective operator, join the hub and network
///    threads, and drop the runtime.
///
/// # Panics
/// Panics if any of the following happens:
/// * logger setup fails;
/// * the executor runtime cannot be built;
/// * awaiting the spawned `main` task yields an error;
/// * the network-init signal cannot be sent or received;
/// * the barrier returns an error;
/// * either helper thread panics (`join().unwrap()`).
pub fn genesis<F, FOUT, MOUT>(main: F) -> MOUT
where
    F: FnOnce(Vec<String>) -> FOUT,
    FOUT: Future<Output = MOUT> + Send + 'static,
    MOUT: Send + 'static,
{
    logging::setup_logger().unwrap();
    meta_data::show_data();
    runtime_meta::init_func_table();
    runtime_meta::init_helpers();
    // Every incoming message (source rank + raw bytes) is handed to the runtime's
    // receive callback, parameterized with SquashBufferFactory.
    let msg_recv_callback =
        |src: Rank, data: &[u8]| message_recv_callback::<SquashBufferFactory>(src, data);
    let mut context = CommunicationContext::new(msg_recv_callback);
    let world_size = context.world_size();
    // NOTE(review): `main` is invoked here, before the init_* calls below. Any
    // synchronous work it does outside its returned future runs before the worker
    // queue, channels, place info and collective operator are set up.
    let main_fut = main(context.cmd_args().to_vec());
    let factory = Box::new(SquashBufferFactory::new());
    init_worker_task_queue();
    init_task_item_channels();
    place::init_here(context.here().as_place());
    place::init_world_size(world_size);
    init_collective_operator(&context);
    let sender = context.single_sender();
    let buffer_receiver = take_message_buffer_receiver();
    let distributor = Distributor::new(factory, world_size, sender, buffer_receiver);
    let mut hub = ExecutionHub::new(distributor);
    // Take the trigger before `hub` is moved into its thread; it is used at
    // shutdown (`trigger.stop()`) to stop the hub.
    let trigger = hub.get_trigger();
    info!("start network loop");
    // One-shot signal from the network thread that `context.init()` has completed.
    let (init_done_s, init_done_r) = std::sync::mpsc::channel::<()>();
    let network_thread = thread::spawn(move || {
        context.init();
        init_done_s.send(()).unwrap();
        context.run();
    });
    info!("start execution hub");
    let hub_thread = thread::spawn(move || hub.run());
    let rt = executor::runtime::Builder::new_multi_thread()
        .worker_threads(*meta_data::NUM_CPUS)
        .thread_name("crayfish-worker")
        .build()
        .unwrap();
    // Dispatch loop: each task received from the worker task channel is run as
    // its own executor task via `worker_dispatch`. The loop ends when `recv()`
    // returns `None`.
    let worker_loop = async move {
        let mut task_receiver = take_worker_task_receiver();
        while let Some(task) = task_receiver.recv().await {
            executor::spawn(worker_dispatch(*task));
        }
        debug!("worker task loop stops");
    };
    // The worker loop is spawned without awaiting its handle; only `main_fut` is
    // awaited, and its output becomes this function's return value.
    let ret = rt.block_on(async move {
        executor::spawn(worker_loop);
        executor::spawn(main_fut).await.unwrap()
    });
    // Block until the network thread reports that `context.init()` finished
    // before using the collective operator for the barrier.
    init_done_r.recv().unwrap();
    let mut coll = collective::take_coll();
    rt.block_on(coll.barrier().map(|r| r.unwrap()));
    // NOTE(review): shutdown sequence is order-sensitive as written: stop the hub,
    // drop the collective operator, join the hub thread, then the network thread,
    // and only then drop the runtime. Confirm before reordering.
    trigger.stop();
    drop(coll);
    hub_thread.join().unwrap();
    network_thread.join().unwrap();
    drop(rt);
    #[cfg(feature = "trace")]
    crate::trace::print_profiling();
    info!("exit gracefully");
    ret
}