use ini::Ini;
use log::{debug, error};
use pyo3::prelude::*;
use pythonize::{depythonize, pythonize};
use serde_value::Value;
use std::collections::{btree_map, BTreeMap};
use std::fmt;
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::time::sleep;
// Engine life-cycle states, stored in ENGINE_STATE and reported by get_engine_state().
pub const STATE_STOPPED: u8 = 0;
pub const STATE_STARTING: u8 = 1;
pub const STATE_STOPPING: u8 = 2;
// Fully running state; deliberately not 3 so that the running state stands out.
pub const STATE_STARTED: u8 = 0xff;
// Default capacity of the task mpsc channel; can be changed via set_mpsc_buffer().
const DATACHANNEL_DEFAULT_BUFFER: usize = 1024;
// Sleep interval for busy-wait loops (try_write retries, wait_online/wait_offline).
const PIME_POLL_DELAY: Duration = Duration::from_millis(1);
#[macro_use]
extern crate lazy_static;
// Current engine state; always one of the STATE_* constants above.
static ENGINE_STATE: AtomicU8 = AtomicU8::new(STATE_STOPPED);
/// Category of an error produced by the engine.
#[derive(Debug, Eq, PartialEq)]
pub enum ErrorKind {
    PyException,
    PackError,
    UnpackError,
    ExecError,
    InternalError,
    PySyncEngineStateError,
}

impl ErrorKind {
    /// Human-readable label for the error category.
    fn as_str(&self) -> &'static str {
        match self {
            ErrorKind::PyException => "Python exception",
            ErrorKind::PackError => "Data pack error",
            ErrorKind::UnpackError => "Data unpack error",
            ErrorKind::ExecError => "Task execution error",
            ErrorKind::InternalError => "Internal error",
            ErrorKind::PySyncEngineStateError => "Engine state error",
        }
    }
}

impl fmt::Display for ErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Error returned by engine operations and task calls.
#[derive(Debug)]
pub struct Error {
    // Category of the failure.
    pub kind: ErrorKind,
    // Human-readable description; may be empty for bare Python exceptions.
    pub message: String,
    // Python exception class name; set only when kind == ErrorKind::PyException.
    pub exception: Option<String>,
    // Python traceback text; set only for Python exceptions.
    pub traceback: Option<String>,
}
impl fmt::Display for Error {
    /// Format as "Python exception <class>: <msg>" for Python errors
    /// (class / colon parts included only when present), or
    /// "<kind>: <msg>" for every other kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.kind == ErrorKind::PyException {
            let mut prefix = String::from("Python exception");
            if let Some(name) = self.exception.as_ref() {
                prefix.push(' ');
                prefix.push_str(name);
                if !self.message.is_empty() {
                    prefix.push(':');
                }
            }
            write!(f, "{} {}", prefix, self.message)
        } else {
            write!(f, "{}: {}", self.kind, self.message)
        }
    }
}
impl From<PyErr> for Error {
fn from(e: PyErr) -> Error {
Error::new_internal(e)
}
}
impl<T> From<tokio::sync::mpsc::error::SendError<T>> for Error
where
T: std::fmt::Debug,
{
fn from(e: tokio::sync::mpsc::error::SendError<T>) -> Error {
Error::new_internal(e)
}
}
impl From<tokio::sync::TryLockError> for Error {
fn from(e: tokio::sync::TryLockError) -> Error {
Error::new_internal(e)
}
}
impl Error {
    /// Create an error of the given kind with a displayable message.
    pub fn new<T: fmt::Display>(kind: ErrorKind, message: T) -> Self {
        Self {
            kind,
            message: message.to_string(),
            exception: None,
            traceback: None,
        }
    }
    /// Build an error from a (class-name, message, traceback) triple
    /// as reported by the Python side.
    fn new_py(error: (String, String, String)) -> Self {
        let (exception, message, traceback) = error;
        Self {
            kind: ErrorKind::PyException,
            message,
            exception: Some(exception),
            traceback: Some(traceback),
        }
    }
    // Shared constructor for engine-state errors (no exception/traceback).
    fn state_error(message: String) -> Self {
        Self {
            kind: ErrorKind::PySyncEngineStateError,
            message,
            exception: None,
            traceback: None,
        }
    }
    // NOTE(review): kind is PySyncEngineStateError rather than
    // InternalError despite the message text — preserved as-is; confirm
    // whether callers depend on this kind.
    fn new_internal<T: fmt::Display>(message: T) -> Self {
        Self::state_error(format!("CRITICAL: PySyncEngine internal error: {}", message))
    }
    fn new_offline() -> Self {
        Self::state_error("PySyncEngine is offline".to_owned())
    }
    fn new_online() -> Self {
        Self::state_error("PySyncEngine is online".to_owned())
    }
}
/// A task to be executed by the embedded Python broker.
#[derive(Debug)]
pub struct PyTask {
    command: Value,
    params: BTreeMap<String, Value>,
    need_result: bool,
    exclusive: bool,
}

impl PyTask {
    /// Create a task with a command and a parameter map.
    /// By default the caller waits for the result.
    #[must_use]
    pub fn new(command: Value, params: BTreeMap<String, Value>) -> Self {
        Self {
            command,
            params,
            need_result: true,
            exclusive: false,
        }
    }
    /// Create a task with a command and no parameters.
    #[must_use]
    pub fn new0(command: Value) -> Self {
        Self::new(command, BTreeMap::new())
    }
    /// Mark the task as fire-and-forget: no result is collected.
    pub fn no_wait(&mut self) {
        self.need_result = false;
    }
    /// Mark the task as exclusive; exclusive tasks always produce a result.
    pub fn mark_exclusive(&mut self) {
        self.need_result = true;
        self.exclusive = true;
    }
}
// Channel carrying (task_id, task) pairs from async callers to the engine
// loop. id 0 with Some(task) is a fire-and-forget task; (0, None) is the
// stop marker.
struct DataChannel {
    tx: Mutex<mpsc::Sender<(u64, Option<PyTask>)>>,
    rx: Mutex<mpsc::Receiver<(u64, Option<PyTask>)>>,
}

impl DataChannel {
    fn new() -> Self {
        Self::with_capacity(DATACHANNEL_DEFAULT_BUFFER)
    }
    // Build a channel pair with the given buffer capacity.
    fn with_capacity(buffer: usize) -> Self {
        let (tx, rx) = mpsc::channel::<(u64, Option<PyTask>)>(buffer);
        Self {
            tx: Mutex::new(tx),
            rx: Mutex::new(rx),
        }
    }
    // Replace both halves with a freshly sized channel.
    fn set_buffer(&mut self, buffer: usize) {
        *self = Self::with_capacity(buffer);
    }
}
/// Pending result slot for a submitted task, stored in PY_RESULTS.
#[derive(Debug)]
struct PyTaskResult {
    task_id: u64,
    // Fired once result/error has been filled in, waking the awaiting caller.
    ready: triggered::Trigger,
    result: Option<Value>,
    error: Option<Error>,
}

impl PyTaskResult {
    // (removed a stale #[allow(clippy::redundant_closure)] — the method
    // contains no closure, so the allow suppressed nothing)
    /// Store the task's result value (None is a valid "no value" result).
    fn set_result(&mut self, result: Option<Value>) {
        self.result = result;
    }
    /// Store an error for the task.
    fn set_error(&mut self, error: Error) {
        self.error = Some(error);
    }
}
/// Monotonic task-id generator. Ids start at 1 and wrap from u64::MAX back
/// to 1, skipping 0 (0 is reserved for fire-and-forget tasks).
struct PyTaskCounter {
    id: u64,
}

impl PyTaskCounter {
    fn new() -> Self {
        Self { id: 0 }
    }
    /// Return the next task id.
    fn get(&mut self) -> u64 {
        self.id = if self.id == u64::MAX { 1 } else { self.id + 1 };
        self.id
    }
}
lazy_static! {
    // Pending task results keyed by task id; entries are inserted by call()
    // and completed by report_result() / report_error().
    static ref PY_RESULTS: RwLock<BTreeMap<u64, PyTaskResult>> = RwLock::new(BTreeMap::new());
    // Generator of unique task ids.
    static ref TASK_COUNTER: Mutex<PyTaskCounter> = Mutex::new(PyTaskCounter::new());
    // Channel between async callers and the blocking engine loop.
    static ref DC: RwLock<DataChannel> = RwLock::new(DataChannel::new());
}
/// Engine handle wrapping the imported Python "neotasker.embed" module;
/// the lifetime is tied to the borrowed Python interpreter token.
pub struct PySyncEngine<'p> {
    neo: &'p pyo3::types::PyModule,
}
/// Return Err(offline) from the enclosing function unless the engine is
/// in STATE_STARTED.
macro_rules! need_online {
    () => {
        if ENGINE_STATE.load(Ordering::SeqCst) != STATE_STARTED {
            return Err(Error::new_offline());
        }
    };
}
/// Return Err(online) from the enclosing function unless the engine is
/// fully stopped.
macro_rules! need_offline {
    () => {
        if ENGINE_STATE.load(Ordering::SeqCst) != STATE_STOPPED {
            return Err(Error::new_online());
        }
    };
}
/// Log a critical engine condition. (Typo "CRIICAL" fixed to "CRITICAL".)
macro_rules! critical {
    ($msg: expr) => {
        error!("PySyncEngine CRITICAL: {}", $msg);
    };
}
/// Log a task whose result slot is missing from PY_RESULTS.
/// (Typo "CRIICAL" fixed to "CRITICAL".)
macro_rules! log_lost_task {
    ($task_id: expr) => {
        error!("PySyncEngine CRITICAL: task {} is lost", $task_id);
    };
}
/// Deliver `error` to the pending result slot for `task_id` and wake the
/// waiting caller. Spins with a short sleep until the result map's write
/// lock can be taken without blocking (this runs on a Python thread, so it
/// must not block the async runtime). Logs and gives up if the slot is gone.
fn report_error(task_id: u64, error: Error) {
    loop {
        match PY_RESULTS.try_write() {
            Ok(mut results) => {
                match results.get_mut(&task_id) {
                    Some(slot) => {
                        slot.set_error(error);
                        slot.ready.trigger();
                    }
                    None => log_lost_task!(task_id),
                }
                return;
            }
            Err(_) => std::thread::sleep(PIME_POLL_DELAY),
        }
    }
}
#[pyfunction]
fn report_result(
py: Python,
task_id: u64,
result: Option<Py<PyAny>>,
error: Option<(String, String, String)>,
) {
let data: Option<Value> = if let Some(r) = result {
match depythonize(r.as_ref(py)) {
Ok(v) => v,
Err(e) => {
report_error(task_id, Error::new(ErrorKind::UnpackError, e));
return;
}
}
} else {
None
};
loop {
if let Ok(mut v) = PY_RESULTS.try_write() {
if let Some(o) = v.get_mut(&task_id) {
o.set_result(data);
if let Some(e) = error {
o.set_error(Error::new_py(e));
}
o.ready.trigger();
break;
}
log_lost_task!(task_id);
return;
}
std::thread::sleep(PIME_POLL_DELAY);
continue;
}
}
impl<'p> PySyncEngine<'p> {
    /// Create an engine using the current Python environment.
    pub fn new(py: &'p pyo3::Python) -> Result<Self, Error> {
        PySyncEngine::new_engine(py, None)
    }
    /// Create an engine, activating the virtual environment at `venv_path`.
    pub fn new_venv(py: &'p pyo3::Python, venv_path: &str) -> Result<Self, Error> {
        PySyncEngine::new_engine(py, Some(venv_path))
    }
    /// Shared constructor: optionally activates a venv (validating its
    /// version and adjusting sys.path), then imports "neotasker.embed" and
    /// registers the `report_result` callback into it.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[allow(clippy::too_many_lines)]
    fn new_engine(py: &'p pyo3::Python, venv_path: Option<&str>) -> Result<Self, Error> {
        // Only one engine may exist at a time.
        if ENGINE_STATE.load(Ordering::SeqCst) != STATE_STOPPED {
            return Err(Error::new_online());
        }
        if let Some(dir) = venv_path {
            // Parse pyvenv.cfg to validate the venv before touching sys.path.
            let cfg = format!("{}/pyvenv.cfg", dir);
            let ini = match Ini::load_from_file(&cfg) {
                Ok(v) => v,
                Err(e) => {
                    return Err(Error::new(
                        ErrorKind::InternalError,
                        format!("Unable to read venv config file {}: {}", cfg, e),
                    ));
                }
            };
            let ver_info = py.version_info();
            // Bail out if a version component fails to parse.
            macro_rules! unwrap_ver_err {
                ($v: expr) => {
                    match $v {
                        Ok(v) => v,
                        Err(e) => {
                            return Err(Error::new(
                                ErrorKind::PyException,
                                format!("Unable to parse venv version info: {}", e),
                            ));
                        }
                    }
                };
            }
            // Bail out if a version component is missing entirely.
            macro_rules! unwrap_ver {
                ($v: expr) => {
                    if let Some(v) = $v {
                        v
                    } else {
                        return Err(Error::new(
                            ErrorKind::PyException,
                            "Unable to get venv version info".to_owned(),
                        ));
                    }
                };
            }
            // The venv's major.minor must match the embedded interpreter.
            let venv_ver = unwrap_ver!(ini.general_section().get("version"));
            let mut s = venv_ver.split('.');
            let venv_major = unwrap_ver_err!(unwrap_ver!(s.next()).parse::<u8>());
            let venv_minor = unwrap_ver_err!(unwrap_ver!(s.next()).parse::<u8>());
            if venv_major != ver_info.major || venv_minor != ver_info.minor {
                return Err(Error::new(
                    ErrorKind::PyException,
                    format!(
                        "Unable to activate venv, \
                        Python library version: {}.{}, venv version: {}. \
                        Please switch the library or rebuild venv",
                        ver_info.major, ver_info.minor, venv_ver
                    ),
                ));
            }
            if let Some(v) = ini.general_section().get("include-system-site-packages") {
                if v == "false" {
                    // The venv excludes system packages: strip any
                    // *-packages / dist-packages / site-packages entries
                    // from sys.path.
                    debug!("Removing system-site packages from Python path");
                    py.run(
                        "import sys;list(map(lambda x:sys.path.remove(x) \
                        if x.endswith('-packages') or '/dist-packages/' in x or \
                        '/site-packages/' in x else False, sys.path.copy()))",
                        None,
                        None,
                    )?;
                }
            }
            // Prepend the venv's site-packages so its modules win imports.
            let import_path = format!(
                "{}/lib/python{}.{}/site-packages",
                dir, ver_info.major, ver_info.minor
            );
            debug!("Adding Python venv import path: {}", import_path);
            py.run(
                &format!("import sys;sys.path.insert(0,'{}')", import_path),
                None,
                None,
            )?;
        }
        // Import the embedded broker module and register the Rust callback
        // Python uses to deliver task results back.
        let neo = py.import("neotasker.embed")?;
        neo.add_function(wrap_pyfunction!(report_result, neo)?)?;
        Ok(Self { neo })
    }
    /// Append `dir` to the Python import path.
    pub fn add_import_path(&self, dir: &str) -> Result<(), Error> {
        self.neo.call_method1("add_import_path", (dir,))?;
        Ok(())
    }
    /// Enable debug mode in the embedded broker module.
    pub fn enable_debug(&self) -> Result<(), Error> {
        self.neo.call_method0("set_debug")?;
        Ok(())
    }
    /// Set the broker's poll delay.
    // NOTE(review): units assumed to be seconds from the f32 type —
    // confirm against neotasker documentation.
    pub fn set_poll_delay(&self, delay: f32) -> Result<(), Error> {
        self.neo.call_method1("set_poll_delay", (delay,))?;
        Ok(())
    }
    /// Set min/max size of the broker thread pool.
    // The pair is passed to Python as a single (min, max) tuple argument.
    pub fn set_thread_pool_size(&self, min: u32, max: u32) -> Result<(), Error> {
        self.neo
            .call_method1("set_thread_pool_size", ((min, max),))?;
        Ok(())
    }
    /// Run the engine event loop. Blocks the calling thread until the stop
    /// marker `(_, None)` arrives (see `stop()`).
    ///
    /// Tasks from the data channel are dispatched to `broker`: exclusive
    /// tasks via "call_direct", regular result-bearing tasks via "call",
    /// fire-and-forget tasks via "spawn".
    ///
    /// # Errors
    /// Fails if the engine is already started, the data channel breaks, or
    /// Python start/stop calls raise.
    pub fn launch(&self, py: &'p pyo3::Python, broker: &pyo3::PyAny) -> Result<(), Error> {
        need_offline!();
        let dc = DC.try_read()?;
        let mut rx = dc.rx.try_lock()?;
        self.neo.call_method0("start")?;
        let call = self.neo.getattr("call")?;
        let call_direct = self.neo.getattr("call_direct")?;
        let spawn = self.neo.getattr("spawn")?;
        ENGINE_STATE.store(STATE_STARTED, Ordering::SeqCst);
        debug!("PySyncEngine started");
        loop {
            // allow_threads releases the GIL while blocking on the channel,
            // so Python worker threads can run in the meantime.
            if let Some((task_id, t)) = py.allow_threads(|| rx.blocking_recv()) {
                if let Some(task) = t {
                    let command = match pythonize(*py, &task.command) {
                        Ok(v) => v,
                        Err(e) => {
                            // id 0 = fire-and-forget: no result slot exists,
                            // so there is nowhere to report the error.
                            if task_id != 0 {
                                report_error(task_id, Error::new(ErrorKind::PackError, e));
                            };
                            continue;
                        }
                    };
                    let params = match pythonize(*py, &task.params) {
                        Ok(v) => v,
                        Err(e) => {
                            if task_id != 0 {
                                report_error(task_id, Error::new(ErrorKind::PackError, e));
                            };
                            continue;
                        }
                    };
                    if task.exclusive {
                        if let Err(e) = call_direct.call1((task_id, broker, command, params)) {
                            report_error(task_id, Error::new(ErrorKind::ExecError, e));
                        }
                    } else if task.need_result {
                        if let Err(e) = call.call1((task_id, broker, command, params)) {
                            report_error(task_id, Error::new(ErrorKind::ExecError, e));
                        }
                    } else {
                        // Fire-and-forget: spawn and ignore the outcome.
                        let _r = spawn.call1((broker, command, params));
                    }
                } else {
                    // None task = shutdown request.
                    ENGINE_STATE.store(STATE_STOPPING, Ordering::SeqCst);
                    break;
                }
            } else {
                return Err(Error::new_internal("channel broken".to_owned()));
            }
        }
        debug!("Stopping PySyncEngine");
        self.neo.call_method0("stop")?;
        debug!("PySyncEngine stopped");
        ENGINE_STATE.store(STATE_STOPPED, Ordering::SeqCst);
        Ok(())
    }
}
/// Submit a task to the engine and await its result.
///
/// Fire-and-forget tasks (created with `no_wait()`) are sent with the
/// reserved id 0 and return `Ok(None)` immediately.
///
/// # Errors
/// Returns an error if the engine is offline, the task channel is broken,
/// or the Python side reports a pack/exec error or an exception.
pub async fn call(task: PyTask) -> Result<Option<Value>, Error> {
    need_online!();
    if !task.need_result {
        // No result slot is allocated for fire-and-forget tasks.
        DC.read()
            .await
            .tx
            .lock()
            .await
            .send((0, Some(task)))
            .await?;
        return Ok(None);
    }
    let (trigger, listener) = triggered::trigger();
    // Allocate a unique id and insert the pending-result slot. If the
    // counter wraps onto an id whose task never finished, log and retry.
    let task_id = loop {
        let cid = TASK_COUNTER.lock().await.get();
        if let btree_map::Entry::Vacant(x) = PY_RESULTS.write().await.entry(cid) {
            x.insert(PyTaskResult {
                task_id: cid,
                result: None,
                error: None,
                ready: trigger,
            });
            break cid;
        }
        critical!("dead tasks in result map");
    };
    // BUGFIX: if the send fails, remove the slot we just inserted —
    // otherwise the dead entry would remain in PY_RESULTS forever and its
    // id could never be reused after the counter wraps.
    if let Err(e) = DC
        .read()
        .await
        .tx
        .lock()
        .await
        .send((task_id, Some(task)))
        .await
    {
        PY_RESULTS.write().await.remove(&task_id);
        return Err(e.into());
    }
    // Wait until report_result()/report_error() fires the trigger.
    listener.await;
    PY_RESULTS.write().await.remove(&task_id).map_or_else(
        || {
            Err(Error::new(
                ErrorKind::InternalError,
                "CRITICAL: Result not found, engine broken".to_owned(),
            ))
        },
        |res| res.error.map_or(Ok(res.result), Err),
    )
}
/// Request engine shutdown and wait until it is fully stopped.
pub async fn stop() -> Result<(), Error> {
    need_online!();
    // (0, None) is the shutdown marker understood by the engine loop;
    // the channel locks are released before waiting for the state change.
    {
        let dc = DC.read().await;
        let tx = dc.tx.lock().await;
        tx.send((0, None)).await?;
    }
    wait_offline().await;
    Ok(())
}
/// Current engine state — one of the STATE_* constants.
// #[must_use] added for consistency with is_engine_started().
#[must_use]
pub fn get_engine_state() -> u8 {
    ENGINE_STATE.load(Ordering::SeqCst)
}
/// `true` when the engine loop is running and accepting tasks.
#[must_use]
pub fn is_engine_started() -> bool {
    ENGINE_STATE.load(Ordering::SeqCst) == STATE_STARTED
}
/// Suspend until the engine reports STATE_STARTED, polling every
/// PIME_POLL_DELAY.
pub async fn wait_online() {
    loop {
        if ENGINE_STATE.load(Ordering::SeqCst) == STATE_STARTED {
            break;
        }
        sleep(PIME_POLL_DELAY).await;
    }
}
/// Suspend until the engine reports STATE_STOPPED, polling every
/// PIME_POLL_DELAY.
pub async fn wait_offline() {
    loop {
        if ENGINE_STATE.load(Ordering::SeqCst) == STATE_STOPPED {
            break;
        }
        sleep(PIME_POLL_DELAY).await;
    }
}
/// Replace the task channel with one of the given capacity.
///
/// # Panics
/// Panics if the data channel is currently in use (e.g. the engine loop is
/// running or a task is being submitted), since the exclusive lock can not
/// be taken without blocking.
pub fn set_mpsc_buffer(buffer: usize) {
    DC.try_write()
        .expect("PySyncEngine: data channel is busy, unable to resize the buffer")
        .set_buffer(buffer);
}