cog_rust/
lib.rs

1#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
2
3use anyhow::Result;
4use clap::Parser;
5use tracing_subscriber::{
6	prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
7};
8
9pub use cog_core::{Cog, CogResponse};
10pub use spec::Path;
11
12mod errors;
13mod helpers;
14mod prediction;
15mod routes;
16mod runner;
17mod server;
18mod shutdown;
19mod spec;
20mod webhooks;
21
22#[derive(Debug, clap::Parser)]
23pub(crate) struct Cli {
24	/// Dump the schema and exit
25	#[clap(long)]
26	dump_schema_and_exit: bool,
27
28	/// Ignore SIGTERM and wait for a request to /shutdown (or a SIGINT) before exiting
29	#[arg(long, default_missing_value = "true", require_equals = true, num_args=0..=1, action = clap::ArgAction::Set)]
30	await_explicit_shutdown: Option<bool>,
31
32	/// An endpoint for Cog to PUT output files to
33	#[clap(long)]
34	upload_url: Option<url::Url>,
35}
36
37/// Start the server with the given model.
38///
39/// # Errors
40///
41/// This function will return an error if the PORT environment variable is set but cannot be parsed, or if the server fails to start.
42pub async fn start<T: Cog + 'static>() -> Result<()> {
43	let args = Cli::parse();
44
45	if !args.dump_schema_and_exit {
46		tracing_subscriber::registry()
47			.with(tracing_subscriber::fmt::layer().with_filter(
48				EnvFilter::try_from_default_env().unwrap_or_else(|_| "cog_rust=info".into()),
49			))
50			.init();
51	}
52
53	server::start::<T>(args).await
54}
55
/// Start the server with the given model.
///
/// Expands to a `main` function that runs [`start`] for the given model type
/// and panics if it returns an error:
///
/// ```ignore
/// cog_rust::start!(MyModel);
/// ```
///
/// The model may be any type path (e.g. `models::MyModel`). The generated
/// `main` is annotated with `#[tokio::main]`, which is resolved in the calling
/// crate, so that crate needs its own `tokio` dependency with the macro and
/// runtime features enabled.
#[macro_export]
macro_rules! start {
	($model:ty $(,)?) => {
		#[tokio::main]
		async fn main() {
			// `$crate` keeps the path valid even if the dependency is renamed.
			$crate::start::<$model>().await.unwrap();
		}
	};
}