mod plugin;
mod runner;
use crate::{
types::{CancelWorkflowRun, CosmicusRunner, Id, Message, RunWorkflowConfig, RunnerConfig},
COSMICUS_VERSION,
};
use coodev_runner::{
Action, CreateWorkflowOptions, GithubAuthorization, Runner, Secret, Volume, Workflow,
};
use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
pub use plugin::ServerPlugin;
use plugin::ServerPluginManager;
pub use runner::run_workflow;
use runner::{create_runner, RunnerCommand};
use std::{collections::HashMap, sync::Arc};
use tokio::{
net::{TcpListener, TcpStream},
sync::mpsc,
};
use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
/// Server-side bookkeeping for one runner registered over a websocket connection.
///
/// Created when a connection sends `Message::RegisterRunner` and stored in the
/// `runners` map of [`State`].
#[derive(Debug, Clone)]
pub struct RunnerState {
/// Identifier the runner registered with; also its key in the `runners` map.
id: Id,
/// Metadata reported by the runner at registration (name, version, `max_runs`).
runner: CosmicusRunner,
/// Number of workflow runs dispatched to this runner; compared against
/// `runner.max_runs` when choosing a runner for a new run.
/// NOTE(review): confirm where this is decremented when a run completes.
running_runs: u64,
/// Sending half of the channel drained by this runner's connection task, which
/// serializes each command and sends it to the runner over the websocket.
command_sender: mpsc::Sender<RunnerCommand>,
}
/// Mutable server state shared (behind `Arc<Mutex<_>>`) between the server handle and
/// the per-connection tasks.
pub struct State {
/// Registered runners, keyed by runner id.
runners: HashMap<Id, RunnerState>,
/// Workflow run id -> id of the runner the run was dispatched to.
runs: HashMap<Id, Id>,
/// Configuration (GitHub authorization, secrets, volumes) sent to each runner when it
/// registers.
config: RunnerConfig,
/// Plugins notified about runner registration, workflow events, dispatched runs and
/// cancellations.
plugin_manager: ServerPluginManager,
}
/// Handle to a Cosmicus server that accepts runner connections and dispatches workflow
/// runs to them.
///
/// All clones share the same [`State`] through `Arc<Mutex<_>>`. Construct one with
/// [`CosmicusServer::builder`].
#[derive(Clone)]
pub struct CosmicusServer {
/// Host part of the address the listener binds in `run`.
host: String,
/// Port part of the address the listener binds in `run`.
port: u32,
/// Local coodev runner used to parse workflows and hold registered secrets, volumes
/// and actions.
runner: Runner,
/// State shared with every connection task.
state: Arc<Mutex<State>>,
}
async fn handle_connection(stream: TcpStream, state: Arc<Mutex<State>>) -> anyhow::Result<()> {
let mut ws_stream = accept_async(stream).await.map_err(|e| {
log::error!("Error during the websocket handshake occurred: {}", e);
anyhow::anyhow!("Error during the websocket handshake occurred: {}", e)
})?;
let (sender, mut receiver) = mpsc::channel::<RunnerCommand>(100);
let mut runner_id: Option<Id> = None;
loop {
tokio::select! {
Some(msg) = ws_stream.next() => {
let msg = msg?;
match msg {
WsMessage::Text(text) => {
let message: Message = match serde_json::from_str(&text) {
Ok(message) => message,
Err(e) => {
log::error!("Error parsing message: {}", e);
continue;
}
};
match message {
Message::RegisterRunner(runner) => {
log::info!("Registering runner: {}", runner.id);
if runner.version != COSMICUS_VERSION {
log::error!(
"Runner version mismatch: {} != {}",
runner.version,
COSMICUS_VERSION
);
if let Err(err) = ws_stream.send(WsMessage::Close(None)).await {
log::error!("Error sending close message: {}", err);
}
continue;
}
runner_id = Some(runner.id.clone());
let msg;
{
let mut state = state.lock();
let runner_for_plugin = runner.clone();
state.runners.insert(
runner.id.clone(),
RunnerState {
id: runner.id.clone(),
runner: CosmicusRunner {
id: runner.id,
name: runner.name,
version: runner.version,
max_runs: runner.max_runs,
},
running_runs: 0,
command_sender: sender.clone(),
},
);
state.plugin_manager.on_runner_registered(&runner_for_plugin);
let config = state.config.clone();
msg = Message::RunnerConfig(config);
}
if let Err(err) = ws_stream
.send(WsMessage::Text(serde_json::to_string(&msg)?))
.await {
log::error!("Error sending message: {}", err);
}
}
Message::WorkflowMessage(msg) => {
state.lock().plugin_manager.on_event(&msg);
},
_ => {}
}
}
WsMessage::Close(_) => {
log::info!("Closing connection.");
if runner_id.is_some() {
state.lock().runners.remove(&runner_id.unwrap());
}
break;
}
_ => (),
}
},
Some(command) = receiver.recv() => {
let msg = match command {
RunnerCommand::Run(run) => Message::RunWorkflow(run),
RunnerCommand::Cancel(cancel) => Message::CancelWorkflowRun(cancel),
};
if let Err(err) = ws_stream.send(WsMessage::Text(serde_json::to_string(&msg)?)).await {
log::error!("Error sending message: {}", err);
}
}
}
}
log::info!("Connection closed.");
Ok(())
}
impl CosmicusServer {
pub fn builder() -> CosmicusServerBuilder {
CosmicusServerBuilder::new()
}
pub async fn run(&self) -> anyhow::Result<()> {
let url = format!("{}:{}", self.host, self.port);
let listener = TcpListener::bind(&url).await?;
log::info!("Listening on: {}", url);
while let Ok((stream, _)) = listener.accept().await {
let state = self.state.clone();
tokio::task::spawn(async move {
if let Err(e) = handle_connection(stream, state).await {
log::error!("Error handling connection: {}", e);
}
});
}
Ok(())
}
pub fn parse_workflow(&self, options: CreateWorkflowOptions) -> coodev_runner::Result<Workflow> {
self.runner.parse_workflow(options)
}
pub async fn run_workflow(&self, options: RunWorkflowConfig) -> anyhow::Result<()> {
let runner = self.get_available_runner().ok_or(anyhow::anyhow!(
"No available runner found for workflow {}",
options.id
))?;
{
let mut state = self.state.lock();
state.runners.insert(
runner.id.clone(),
RunnerState {
running_runs: runner.running_runs + 1,
..runner.clone()
},
);
state.runs.insert(options.id.clone(), runner.id.clone());
}
runner
.command_sender
.send(RunnerCommand::Run(options.clone()))
.await?;
self.state.lock().plugin_manager.on_run_workflow(&options);
Ok(())
}
pub async fn cancel_workflow_run(&self, id: Id) -> anyhow::Result<()> {
let state = self.state.lock();
let runner_id = state
.runs
.get(&id)
.ok_or(anyhow::anyhow!("No runner found for workflow run {}", id))?;
let runner = state
.runners
.get(runner_id)
.ok_or(anyhow::anyhow!("No runner found for id {}", runner_id))?;
runner
.command_sender
.send(RunnerCommand::Cancel(CancelWorkflowRun { id: id.clone() }))
.await?;
self
.state
.lock()
.plugin_manager
.on_cancel_workflow_run(&CancelWorkflowRun { id });
Ok(())
}
pub fn register_secret(&self, secret: Secret) -> &Self {
self.runner.register_secret(secret.clone());
let mut state = self.state.lock();
let secret = secret.into();
state.config.secrets.push(secret);
self
}
pub fn register_volume(&self, volume: Volume) -> anyhow::Result<&Self> {
self.runner.register_volume(volume.clone());
let mut state = self.state.lock();
let volume = volume
.try_into()
.map_err(|e: anyhow::Error| anyhow::anyhow!("Error registering volume: {}", e))?;
state.config.volumes.push(volume);
Ok(self)
}
pub fn register_action<T>(&self, name: impl Into<String>, action: T) -> &Self
where
T: Action + 'static,
{
self.runner.register_action(name, action);
self
}
pub fn register_plugin<T>(&self, plugin: T) -> &Self
where
T: ServerPlugin + Send + Sync + 'static,
{
plugin.on_init(self);
self.state.lock().plugin_manager.register(Box::new(plugin));
self
}
pub(crate) fn get_available_runner(&self) -> Option<RunnerState> {
let state = self.state.lock();
let mut runners: Vec<RunnerState> = state
.runners
.iter()
.filter(|(_, runner)| runner.running_runs < runner.runner.max_runs)
.map(|(_, runner)| runner.clone())
.collect();
runners.sort_by_key(|id| id.running_runs);
runners.first().cloned()
}
}
/// Builder for [`CosmicusServer`]; every field is optional until `build` is called.
pub struct CosmicusServerBuilder {
/// Listener host; `build` falls back to "0.0.0.0".
host: Option<String>,
/// Listener port; `build` falls back to 5001.
port: Option<u32>,
/// GitHub personal access token, used when app credentials are not both set.
github_personal_access_token: Option<String>,
/// GitHub App id; used together with `github_app_private_key`.
github_app_id: Option<u64>,
/// GitHub App private key; used together with `github_app_id`.
github_app_private_key: Option<String>,
}
impl CosmicusServerBuilder {
    /// Creates a builder with no host, port or GitHub credentials set.
    pub fn new() -> Self {
        CosmicusServerBuilder {
            port: None,
            host: None,
            github_personal_access_token: None,
            github_app_id: None,
            github_app_private_key: None,
        }
    }

    /// Sets the listener port. It must fit in a `u16`; `build` rejects larger values.
    pub fn port<T>(mut self, port: T) -> Self
    where
        u32: From<T>,
    {
        self.port = Some(u32::from(port));
        self
    }

    /// Sets the listener host (defaults to "0.0.0.0").
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = Some(host.into());
        self
    }

    /// Sets a GitHub personal access token. It is used only when the GitHub App id and
    /// private key are not both set.
    pub fn github_personal_access_token(mut self, token: impl Into<String>) -> Self {
        self.github_personal_access_token = Some(token.into());
        self
    }

    /// Sets the GitHub App id. It must be paired with `github_app_private_key`.
    pub fn github_app_id<T>(mut self, id: T) -> Self
    where
        u64: From<T>,
    {
        self.github_app_id = Some(u64::from(id));
        self
    }

    /// Sets the GitHub App private key. It must be paired with `github_app_id`.
    pub fn github_app_private_key(mut self, key: impl Into<String>) -> Self {
        self.github_app_private_key = Some(key.into());
        self
    }

    /// Builds the server.
    ///
    /// GitHub App credentials take precedence when both the app id and the private key
    /// are set; otherwise the personal access token is used. The port defaults to 5001
    /// and the host to "0.0.0.0".
    ///
    /// # Errors
    ///
    /// Returns an error if no usable GitHub authorization was provided, if the port
    /// does not fit in a `u16`, or if the local runner cannot be created.
    pub fn build(self) -> anyhow::Result<CosmicusServer> {
        let github_authorization = match (
            self.github_app_id,
            self.github_app_private_key,
            self.github_personal_access_token,
        ) {
            (Some(app_id), Some(private_key), _) => GithubAuthorization::GithubApp {
                app_id,
                private_key,
            },
            (_, _, Some(personal_access_token)) => {
                GithubAuthorization::PersonalAccessToken(personal_access_token)
            }
            _ => return Err(anyhow::anyhow!("No github authorization provided")),
        };

        let port = self.port.unwrap_or(5001);
        // The port is stored as u32 for API compatibility, but only u16 values are valid
        // TCP ports. Reject anything else here rather than failing later at bind time.
        if u16::try_from(port).is_err() {
            return Err(anyhow::anyhow!("Invalid port: {}", port));
        }
        let host = self.host.unwrap_or_else(|| "0.0.0.0".to_string());

        let runner = create_runner(github_authorization.clone())?;
        Ok(CosmicusServer {
            port,
            host,
            runner,
            state: Arc::new(Mutex::new(State {
                runners: HashMap::new(),
                runs: HashMap::new(),
                config: RunnerConfig {
                    github_authorization,
                    volumes: vec![],
                    secrets: vec![],
                },
                plugin_manager: ServerPluginManager::new(),
            })),
        })
    }
}

impl Default for CosmicusServerBuilder {
    fn default() -> Self {
        Self::new()
    }
}