cog_core/
spec.rs

1use anyhow::Result;
2use core::fmt::Debug;
3use schemars::JsonSchema;
4use serde::{de::DeserializeOwned, Serialize};
5use serde_json::Value;
6use std::future::Future;
7
8use crate::http::Request;
9
10/// A Cog model
11pub trait Cog: Sized + Send {
12	type Request: DeserializeOwned + JsonSchema + Send;
13	type Response: CogResponse + Debug + JsonSchema;
14
15	/// Setup the model
16	///
17	/// # Errors
18	///
19	/// Returns an error if setup fails.
20	fn setup() -> impl Future<Output = Result<Self>> + Send;
21
22	/// Run a prediction on the model
23	///
24	/// # Errors
25	///
26	/// Returns an error if the prediction fails.
27	fn predict(&self, input: Self::Request) -> Result<Self::Response>;
28}
29
30/// A response from a Cog model
31pub trait CogResponse: Send {
32	/// Convert the response into a JSON value
33	fn into_response(self, request: Request) -> impl Future<Output = Result<Value>> + Send;
34}
35
36impl<T: Serialize + Send + 'static> CogResponse for T {
37	async fn into_response(self, _: Request) -> Result<Value> {
38		// We use spawn_blocking here to allow blocking code in serde Serialize impls (used in `Path`, for example).
39		Ok(tokio::task::spawn_blocking(move || serde_json::to_value(self)).await??)
40	}
41}