cosmicus 0.1.9

Cosmicus Client and Server to make the runner better
Documentation
mod plugin;
mod runner;

use crate::{
  types::{CancelWorkflowRun, CosmicusRunner, Id, Message, RunWorkflowConfig, RunnerConfig},
  COSMICUS_VERSION,
};
use coodev_runner::{
  Action, CreateWorkflowOptions, GithubAuthorization, Runner, Secret, Volume, Workflow,
};
use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
pub use plugin::ServerPlugin;
use plugin::ServerPluginManager;
pub use runner::run_workflow;
use runner::{create_runner, RunnerCommand};
use std::{collections::HashMap, sync::Arc};
use tokio::{
  net::{TcpListener, TcpStream},
  sync::mpsc,
};
use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};

/// Server-side bookkeeping for one registered runner connection.
#[derive(Debug, Clone)]
pub struct RunnerState {
  /// Identifier the runner announced when it registered.
  id: Id,
  /// Metadata reported by the runner (id, name, version, max_runs).
  runner: CosmicusRunner,
  /// Number of workflow runs dispatched to this runner; compared against
  /// `runner.max_runs` when picking a runner for a new run.
  running_runs: u64,
  /// Channel used to forward commands to the task serving this runner's connection.
  command_sender: mpsc::Sender<RunnerCommand>,
}

/// Mutable server state shared between the public API and the connection tasks.
pub struct State {
  /// Registered runners, keyed by runner id.
  runners: HashMap<Id, RunnerState>,
  /// Maps a workflow run id to the id of the runner it was dispatched to.
  runs: HashMap<Id, Id>,
  /// Configuration sent to a runner right after it registers.
  config: RunnerConfig,
  /// Plugins that are notified about runner and workflow events.
  plugin_manager: ServerPluginManager,
}

/// Server that accepts runner connections and dispatches workflow runs to them.
///
/// Cloning is cheap with respect to the shared state: all clones point at the
/// same `Arc<Mutex<State>>`.
#[derive(Clone)]
pub struct CosmicusServer {
  /// Host name or address the listener binds to.
  host: String,
  /// Port the listener binds to.
  port: u32,
  /// Local runner used for parsing workflows and registering secrets, volumes and actions.
  runner: Runner,
  /// State shared with every connection task.
  state: Arc<Mutex<State>>,
}

async fn handle_connection(stream: TcpStream, state: Arc<Mutex<State>>) -> anyhow::Result<()> {
  let mut ws_stream = accept_async(stream).await.map_err(|e| {
    log::error!("Error during the websocket handshake occurred: {}", e);
    anyhow::anyhow!("Error during the websocket handshake occurred: {}", e)
  })?;

  let (sender, mut receiver) = mpsc::channel::<RunnerCommand>(100);

  let mut runner_id: Option<Id> = None;
  loop {
    tokio::select! {
      Some(msg) = ws_stream.next() => {
        // Handle message
        let msg = msg?;
        match msg {
          WsMessage::Text(text) => {
            let message: Message = match serde_json::from_str(&text) {
              Ok(message) => message,
              Err(e) => {
                log::error!("Error parsing message: {}", e);
                continue;
              }
            };

            match message {
              Message::RegisterRunner(runner) => {
                log::info!("Registering runner: {}", runner.id);
                if runner.version != COSMICUS_VERSION {
                  log::error!(
                    "Runner version mismatch: {} != {}",
                    runner.version,
                    COSMICUS_VERSION
                  );

                  if let Err(err) = ws_stream.send(WsMessage::Close(None)).await {
                    log::error!("Error sending close message: {}", err);
                  }
                  continue;
                }
                runner_id = Some(runner.id.clone());

                let msg;
                {
                  let mut state = state.lock();
                  let runner_for_plugin = runner.clone();
                  state.runners.insert(
                    runner.id.clone(),
                    RunnerState {
                      id: runner.id.clone(),
                      runner: CosmicusRunner {
                        id: runner.id,
                        name: runner.name,
                        version: runner.version,
                        max_runs: runner.max_runs,
                      },
                      running_runs: 0,
                      command_sender: sender.clone(),
                    },
                  );
                  state.plugin_manager.on_runner_registered(&runner_for_plugin);

                  let config = state.config.clone();

                  msg = Message::RunnerConfig(config);
                }

                if let Err(err) = ws_stream
                  .send(WsMessage::Text(serde_json::to_string(&msg)?))
                  .await {
                  log::error!("Error sending message: {}", err);
                  }
              }
              Message::WorkflowMessage(msg) => {
                state.lock().plugin_manager.on_event(&msg);
              },
              _ => {}
            }
          }
          WsMessage::Close(_) => {
            log::info!("Closing connection.");
            if runner_id.is_some() {
              state.lock().runners.remove(&runner_id.unwrap());
            }
            break;
          }
          _ => (),
        }
      },
      Some(command) = receiver.recv() => {
        let msg = match command {
          RunnerCommand::Run(run) => Message::RunWorkflow(run),
          RunnerCommand::Cancel(cancel) => Message::CancelWorkflowRun(cancel),
        };
        // Send command
        if let Err(err) = ws_stream.send(WsMessage::Text(serde_json::to_string(&msg)?)).await {
          log::error!("Error sending message: {}", err);
        }
      }
    }
  }

  log::info!("Connection closed.");

  Ok(())
}

impl CosmicusServer {
  pub fn builder() -> CosmicusServerBuilder {
    CosmicusServerBuilder::new()
  }

  pub async fn run(&self) -> anyhow::Result<()> {
    let url = format!("{}:{}", self.host, self.port);

    let listener = TcpListener::bind(&url).await?;
    log::info!("Listening on: {}", url);

    // let mut handles = vec![];
    while let Ok((stream, _)) = listener.accept().await {
      let state = self.state.clone();
      tokio::task::spawn(async move {
        if let Err(e) = handle_connection(stream, state).await {
          log::error!("Error handling connection: {}", e);
        }
      });
    }

    Ok(())
  }

  pub fn parse_workflow(&self, options: CreateWorkflowOptions) -> coodev_runner::Result<Workflow> {
    self.runner.parse_workflow(options)
  }

  pub async fn run_workflow(&self, options: RunWorkflowConfig) -> anyhow::Result<()> {
    let runner = self.get_available_runner().ok_or(anyhow::anyhow!(
      "No available runner found for workflow {}",
      options.id
    ))?;

    {
      let mut state = self.state.lock();

      state.runners.insert(
        runner.id.clone(),
        RunnerState {
          running_runs: runner.running_runs + 1,
          ..runner.clone()
        },
      );

      state.runs.insert(options.id.clone(), runner.id.clone());
    }

    runner
      .command_sender
      .send(RunnerCommand::Run(options.clone()))
      .await?;

    self.state.lock().plugin_manager.on_run_workflow(&options);

    Ok(())
  }

  pub async fn cancel_workflow_run(&self, id: Id) -> anyhow::Result<()> {
    let state = self.state.lock();
    let runner_id = state
      .runs
      .get(&id)
      .ok_or(anyhow::anyhow!("No runner found for workflow run {}", id))?;

    let runner = state
      .runners
      .get(runner_id)
      .ok_or(anyhow::anyhow!("No runner found for id {}", runner_id))?;

    runner
      .command_sender
      .send(RunnerCommand::Cancel(CancelWorkflowRun { id: id.clone() }))
      .await?;

    self
      .state
      .lock()
      .plugin_manager
      .on_cancel_workflow_run(&CancelWorkflowRun { id });

    Ok(())
  }

  pub fn register_secret(&self, secret: Secret) -> &Self {
    self.runner.register_secret(secret.clone());

    let mut state = self.state.lock();
    let secret = secret.into();
    state.config.secrets.push(secret);

    self
  }

  pub fn register_volume(&self, volume: Volume) -> anyhow::Result<&Self> {
    self.runner.register_volume(volume.clone());

    let mut state = self.state.lock();
    let volume = volume
      .try_into()
      .map_err(|e: anyhow::Error| anyhow::anyhow!("Error registering volume: {}", e))?;

    state.config.volumes.push(volume);

    Ok(self)
  }

  pub fn register_action<T>(&self, name: impl Into<String>, action: T) -> &Self
  where
    T: Action + 'static,
  {
    self.runner.register_action(name, action);

    self
  }

  pub fn register_plugin<T>(&self, plugin: T) -> &Self
  where
    T: ServerPlugin + Send + Sync + 'static,
  {
    plugin.on_init(self);
    self.state.lock().plugin_manager.register(Box::new(plugin));

    self
  }

  // TODO: improve this: load balancing, etc.
  pub(crate) fn get_available_runner(&self) -> Option<RunnerState> {
    let state = self.state.lock();
    // load balancing
    let mut runners: Vec<RunnerState> = state
      .runners
      .iter()
      .filter(|(_, runner)| runner.running_runs < runner.runner.max_runs)
      .map(|(_, runner)| runner.clone())
      .collect();

    // sort by running runs
    runners.sort_by_key(|id| id.running_runs);

    // get first runner
    runners.first().cloned()
  }
}

/// Builder collecting the options needed to create a [`CosmicusServer`].
pub struct CosmicusServerBuilder {
  /// Host to bind to; a default is applied in `build` when unset.
  host: Option<String>,
  /// Port to bind to; a default is applied in `build` when unset.
  port: Option<u32>,
  /// GitHub personal access token, used when no complete GitHub App credentials are set.
  github_personal_access_token: Option<String>,
  /// GitHub App id; only used together with `github_app_private_key`.
  github_app_id: Option<u64>,
  /// GitHub App private key; only used together with `github_app_id`.
  github_app_private_key: Option<String>,
}

impl CosmicusServerBuilder {
  /// Creates a builder with every option unset.
  pub fn new() -> Self {
    CosmicusServerBuilder {
      port: None,
      host: None,
      github_personal_access_token: None,
      github_app_id: None,
      github_app_private_key: None,
    }
  }

  /// Sets the port to listen on (defaults to `5001`).
  pub fn port<T>(mut self, port: T) -> Self
  where
    u32: From<T>,
  {
    self.port = Some(u32::from(port));
    self
  }

  /// Sets the host to bind to (defaults to `0.0.0.0`).
  pub fn host(mut self, host: impl Into<String>) -> Self {
    self.host = Some(host.into());
    self
  }

  /// Sets the GitHub personal access token.
  pub fn github_personal_access_token(mut self, token: impl Into<String>) -> Self {
    self.github_personal_access_token = Some(token.into());
    self
  }

  /// Sets the GitHub App id.
  pub fn github_app_id<T>(mut self, id: T) -> Self
  where
    u64: From<T>,
  {
    self.github_app_id = Some(u64::from(id));
    self
  }

  /// Sets the GitHub App private key.
  pub fn github_app_private_key(mut self, key: impl Into<String>) -> Self {
    self.github_app_private_key = Some(key.into());
    self
  }

  /// Creates the server from the collected options.
  ///
  /// GitHub App credentials take precedence when both the app id and the
  /// private key are set; otherwise the personal access token is used.
  ///
  /// # Errors
  ///
  /// Returns an error when no usable GitHub authorization was provided, when
  /// the port does not fit into a TCP port number, or when the local runner
  /// cannot be created.
  pub fn build(self) -> anyhow::Result<CosmicusServer> {
    let github_authorization = match (
      self.github_app_id,
      self.github_app_private_key,
      self.github_personal_access_token,
    ) {
      (Some(app_id), Some(private_key), _) => GithubAuthorization::GithubApp {
        app_id,
        private_key,
      },
      (_, _, Some(personal_access_token)) => {
        GithubAuthorization::PersonalAccessToken(personal_access_token)
      }
      _ => return Err(anyhow::anyhow!("No github authorization provided")),
    };

    let port = self.port.unwrap_or(5001);
    // The field is a `u32` for API compatibility, but TCP ports are 16 bit.
    // Reject impossible values here rather than failing later at bind time.
    if port > u32::from(u16::MAX) {
      return Err(anyhow::anyhow!(
        "Invalid port {}: must be between 0 and {}",
        port,
        u16::MAX
      ));
    }
    let host = self.host.unwrap_or_else(|| "0.0.0.0".to_string());

    let runner = create_runner(github_authorization.clone())?;

    Ok(CosmicusServer {
      port,
      host,
      runner,
      state: Arc::new(Mutex::new(State {
        runners: HashMap::new(),
        runs: HashMap::new(),
        config: RunnerConfig {
          github_authorization,
          volumes: vec![],
          secrets: vec![],
        },
        plugin_manager: ServerPluginManager::new(),
      })),
    })
  }
}

impl Default for CosmicusServerBuilder {
  /// Equivalent to [`CosmicusServerBuilder::new`].
  fn default() -> Self {
    Self::new()
  }
}