use crate::{SessionConfig, SessionManager, SessionState, SessionType};
use pyo3::prelude::*;
use pyo3_async_runtimes::tokio::future_into_py;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use super::to_py_err;
/// Python-visible selector for the kind of session to start.
///
/// Build it from one of the strings also published as class attributes:
/// `SessionType.STANDARD_STREAM`, `SessionType.PORT` or `SessionType.INTERACTIVE_COMMANDS`.
#[pyclass(name = "SessionType")]
#[derive(Clone)]
pub struct PySessionType {
    inner: SessionType,
}

#[pymethods]
impl PySessionType {
    #[classattr]
    const STANDARD_STREAM: &'static str = "standard_stream";
    #[classattr]
    const PORT: &'static str = "port";
    #[classattr]
    const INTERACTIVE_COMMANDS: &'static str = "interactive_commands";

    /// Parses `session_type` into the corresponding `SessionType` variant.
    ///
    /// # Errors
    /// Raises `ValueError` when the string is not one of the three accepted names.
    #[new]
    fn new(session_type: &str) -> PyResult<Self> {
        let parsed = match session_type {
            "standard_stream" => Some(SessionType::StandardStream),
            "port" => Some(SessionType::Port),
            "interactive_commands" => Some(SessionType::InteractiveCommands),
            _ => None,
        };
        parsed.map(|inner| Self { inner }).ok_or_else(|| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Invalid session type: '{}'. Valid types: 'standard_stream', 'port', 'interactive_commands'",
                session_type
            ))
        })
    }

    /// Representation built from the wrapped enum's `Debug` output.
    fn __repr__(&self) -> String {
        format!("SessionType({inner:?})", inner = self.inner)
    }
}
#[pyclass(name = "SessionConfig")]
#[derive(Clone)]
pub struct PySessionConfig {
inner: SessionConfig,
}
#[pymethods]
impl PySessionConfig {
#[new]
#[pyo3(signature = (target, region=None, session_type=None, document_name=None, parameters=None, reason=None))]
fn new(
target: String,
region: Option<String>,
session_type: Option<PySessionType>,
document_name: Option<String>,
parameters: Option<HashMap<String, Vec<String>>>,
reason: Option<String>,
) -> Self {
let inner = SessionConfig {
target,
region,
session_type: session_type
.map(|t| t.inner)
.unwrap_or(SessionType::StandardStream),
document_name,
parameters: parameters.unwrap_or_default(),
reason,
..Default::default() };
Self { inner }
}
#[getter]
fn target(&self) -> String {
self.inner.target.clone()
}
#[getter]
fn region(&self) -> Option<String> {
self.inner.region.clone()
}
fn __repr__(&self) -> String {
format!(
"SessionConfig(target='{}', region={:?})",
self.inner.target, self.inner.region
)
}
}
#[pyclass(name = "Session")]
pub struct PySession {
inner: Arc<tokio::sync::Mutex<crate::Session>>,
}
#[pymethods]
impl PySession {
#[getter]
fn id<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
Ok(session_guard.id().to_string())
})
}
fn state<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
let state = session_guard.state().await;
let state_str = match state {
SessionState::Initializing => "initializing",
SessionState::Connected => "connected",
SessionState::Disconnecting => "disconnecting",
SessionState::Terminated => "terminated",
};
Ok(state_str)
})
}
fn is_ready<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
Ok(session_guard.is_ready())
})
}
#[pyo3(signature = (timeout_secs = 30.0))]
fn wait_for_ready<'py>(
&self,
py: Python<'py>,
timeout_secs: f64,
) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let timeout = std::time::Duration::from_secs_f64(timeout_secs);
let session_guard = session.lock().await;
Ok(session_guard.wait_for_ready(timeout).await)
})
}
fn output<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
let stream = session_guard.output();
Ok(PyOutputStream {
inner: Arc::new(tokio::sync::Mutex::new(stream)),
})
})
}
fn send<'py>(&self, py: Python<'py>, data: Vec<u8>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
session_guard
.send(bytes::Bytes::from(data))
.await
.map_err(to_py_err)?;
Ok(())
})
}
fn terminate<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let mut session_guard = session.lock().await;
session_guard.terminate().await.map_err(to_py_err)?;
Ok(())
})
}
fn wait_terminated<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
session_guard.wait_terminated().await;
Ok(())
})
}
fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&slf.inner);
future_into_py(py, async move {
let session_guard = session.lock().await;
let _ = session_guard
.wait_for_ready(std::time::Duration::from_secs(30))
.await;
drop(session_guard);
Ok(PySession { inner: session })
})
}
#[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
fn __aexit__<'py>(
&self,
py: Python<'py>,
_exc_type: Option<Bound<'_, PyAny>>,
_exc_val: Option<Bound<'_, PyAny>>,
_exc_tb: Option<Bound<'_, PyAny>>,
) -> PyResult<Bound<'py, PyAny>> {
let session = Arc::clone(&self.inner);
future_into_py(py, async move {
let mut session_guard = session.lock().await;
let _ = session_guard.terminate().await;
Ok(false) })
}
fn __repr__(&self) -> String {
"Session()".to_string()
}
}
#[pyclass(name = "SessionManager")]
pub struct PySessionManager {
inner: Arc<RwLock<SessionManager>>,
}
#[pymethods]
impl PySessionManager {
#[staticmethod]
#[allow(clippy::new_ret_no_self)]
fn new(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
future_into_py(py, async move {
let manager = SessionManager::new().await.map_err(to_py_err)?;
Ok(PySessionManager {
inner: Arc::new(RwLock::new(manager)),
})
})
}
#[pyo3(signature = (target, region=None, session_type=None, document_name=None, parameters=None, reason=None))]
#[allow(clippy::too_many_arguments)]
fn start_session<'py>(
&self,
py: Python<'py>,
target: String,
region: Option<String>,
session_type: Option<String>,
document_name: Option<String>,
parameters: Option<HashMap<String, Vec<String>>>,
reason: Option<String>,
) -> PyResult<Bound<'py, PyAny>> {
let manager = Arc::clone(&self.inner);
let session_type_enum = if let Some(ref st) = session_type {
match st.as_str() {
"standard_stream" => SessionType::StandardStream,
"port" => SessionType::Port,
"interactive_commands" => SessionType::InteractiveCommands,
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid session type: '{}'. Valid types: 'standard_stream', 'port', 'interactive_commands'",
st
)))
}
}
} else {
SessionType::StandardStream
};
future_into_py(py, async move {
let config = SessionConfig {
target,
region,
session_type: session_type_enum,
document_name,
parameters: parameters.unwrap_or_default(),
reason,
..Default::default()
};
let manager_guard = manager.read().await;
let session = manager_guard
.start_session(config)
.await
.map_err(to_py_err)?;
Ok(PySession {
inner: Arc::new(tokio::sync::Mutex::new(session)),
})
})
}
fn terminate_session<'py>(
&self,
py: Python<'py>,
session_id: String,
) -> PyResult<Bound<'py, PyAny>> {
let manager = Arc::clone(&self.inner);
future_into_py(py, async move {
let manager_guard = manager.read().await;
manager_guard
.terminate_session(&session_id)
.await
.map_err(to_py_err)?;
Ok(())
})
}
fn __repr__(&self) -> String {
"SessionManager()".to_string()
}
}
/// Async iterator over a session's output chunks.
///
/// Each `__anext__` awaitable resolves to the next chunk as bytes. Once the underlying stream is
/// exhausted, the awaitable raises `StopAsyncIteration`, which ends an `async for` loop.
#[pyclass(name = "OutputStream")]
pub struct PyOutputStream {
    inner: Arc<tokio::sync::Mutex<crate::OutputStream>>,
}

#[pymethods]
impl PyOutputStream {
    /// Returns the stream itself, as the async-iterator protocol requires.
    fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Returns an awaitable that yields the next chunk or raises `StopAsyncIteration`.
    fn __anext__<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
        use futures::StreamExt;
        use pyo3::exceptions::PyStopAsyncIteration;

        let shared = Arc::clone(&self.inner);
        future_into_py(py, async move {
            let mut guard = shared.lock().await;
            let next_chunk = guard.next().await;
            next_chunk
                .map(|chunk| chunk.to_vec())
                .ok_or_else(|| PyStopAsyncIteration::new_err(()))
        })
        .map(Some)
    }

    fn __repr__(&self) -> String {
        String::from("OutputStream()")
    }
}