1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
use anyhow::Result;
use async_trait::async_trait;
use core::fmt::Debug;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;

use crate::http::Request;

/// A predictor that is initialised once with [`Cog::setup`] and then answers
/// predictions with [`Cog::predict`].
///
/// Implementors must be `Sized + Send`. Both the request and the response types
/// must be describable by a JSON Schema.
#[async_trait]
pub trait Cog: Sized + Send {
	/// Input accepted by [`Cog::predict`].
	///
	/// Must be deserializable ([`DeserializeOwned`]), describable by a JSON
	/// Schema ([`JsonSchema`]), and `Send`.
	type Request: DeserializeOwned + JsonSchema + Send;

	/// Output produced by [`Cog::predict`].
	///
	/// Must implement [`CogResponse`], which converts it into a JSON [`Value`].
	/// It must also be [`Debug`] and describable by a JSON Schema.
	type Response: CogResponse + Debug + JsonSchema;

	/// Set up the cog and return the initialised instance.
	///
	/// # Errors
	///
	/// Returns an error if setup fails.
	async fn setup() -> Result<Self>;

	/// Run a prediction on `input`.
	///
	/// Unlike [`Cog::setup`], this method is synchronous.
	///
	/// # Errors
	///
	/// Returns an error if the prediction fails.
	fn predict(&self, input: Self::Request) -> Result<Self::Response>;
}

/// A response from a cog.
///
/// Converts a prediction output into a JSON [`Value`]. Every type that is
/// [`Serialize`] + `Send + 'static` implements this trait through a blanket
/// implementation.
#[async_trait]
pub trait CogResponse: Send {
	/// Convert `self` into a JSON [`Value`].
	///
	/// `request` is the [`Request`] associated with this response.
	/// Implementations may use it or ignore it; the blanket implementation for
	/// [`Serialize`] types ignores it.
	///
	/// # Errors
	///
	/// Returns an error if `self` cannot be converted into a JSON value.
	async fn into_response(self, request: Request) -> Result<Value>;
}

/// Any serializable, sendable, owned value is a valid response.
///
/// The value is converted with `serde_json::to_value`, and the incoming request
/// is ignored.
#[async_trait]
impl<T> CogResponse for T
where
	T: Serialize + Send + 'static,
{
	async fn into_response(self, _: Request) -> Result<Value> {
		// Serialization runs on tokio's blocking pool because some `Serialize`
		// impls (e.g. `Path`) perform blocking work inside `serialize`.
		let task = tokio::task::spawn_blocking(move || serde_json::to_value(self));

		// First `?`: the blocking task panicked or was cancelled (JoinError).
		let serialized = task.await?;
		// Second `?`: serde_json failed to convert the value.
		let json = serialized?;

		Ok(json)
	}
}