#![allow(clippy::all, unused_imports, dead_code)]
use crate::{Actor, ActorBehavior, Message, Port};
use anyhow::{Error, Result};
use reflow_actor::{message::EncodableValue, ActorContext};
use reflow_actor_macro::actor;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::time::Duration;
/// Root URL of the Replicate HTTP API; the actors in this file append an endpoint path to it.
const BASE_URL: &str = "https://api.replicate.com/v1";
/// Name under which the API token is looked up via `ActorConfig::get_config_or_env`.
const ENV_KEY: &str = "REPLICATE_API_KEY";
/// Attaches the Replicate API token to an outgoing request.
///
/// The token is looked up under [`ENV_KEY`] with `ActorConfig::get_config_or_env`
/// and sent as an `Authorization: Bearer <token>` header.
///
/// # Errors
/// Returns `Missing env var: REPLICATE_API_KEY` when no token can be found.
fn apply_auth(
    config: &reflow_actor::ActorConfig,
    builder: reqwest::RequestBuilder,
) -> Result<reqwest::RequestBuilder> {
    match config.get_config_or_env(ENV_KEY) {
        Some(token) => Ok(builder.header("Authorization", format!("Bearer {}", token))),
        None => Err(anyhow::anyhow!("Missing env var: {}", ENV_KEY)),
    }
}
#[actor(
ReplicateCreatePredictionActor,
inports::<100>(version, input),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_create_prediction(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/predictions".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.post(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut body = serde_json::Map::new();
if let Some(val) = inputs.get("version") {
body.insert("version".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("input") {
body.insert("input".to_string(), val.clone().into());
}
if !body.is_empty() {
builder = builder.json(&serde_json::Value::Object(body));
}
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("POST /predictions failed: {}", e).into()),
);
}
}
Ok(output)
}
/// Actor for `GET /predictions/{id}`: fetches one prediction by id.
///
/// The `id` port is substituted into the request path. Every HTTP reply, whatever
/// its status code, is emitted on `response` as `{status, headers, body}`. `body`
/// is the parsed JSON, or the raw text when the reply is not JSON. Transport
/// failures are emitted on `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateReadPredictionActor,
    inports::<100>(id),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_read_prediction(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholder to
    // replace is the single-brace `{id}`.
    // NOTE(review): the value is inserted without percent-encoding.
    let mut endpoint = "/predictions/{id}".to_string();
    if let Some(val) = inputs.get("id") {
        endpoint = endpoint.replace("{id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(format!("GET /predictions/{{id}} failed: {}", e).into()),
            );
        }
    }
    Ok(output)
}
#[actor(
ReplicateListAccountActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_account(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/account".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.get(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("GET /account failed: {}", e).into()),
);
}
}
Ok(output)
}
#[actor(
ReplicateListCollectionsActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_collections(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/collections".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.get(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("GET /collections failed: {}", e).into()),
);
}
}
Ok(output)
}
/// Actor for `GET /collections/{collection_slug}`: fetches one collection.
///
/// The `collection_slug` port is substituted into the request path. Every HTTP
/// reply is emitted on `response` as `{status, headers, body}`. `body` is the
/// parsed JSON, or the raw text when the reply is not JSON. Transport failures are
/// emitted on `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateReadCollectionsActor,
    inports::<100>(collection_slug),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_read_collections(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholder to
    // replace is the single-brace `{collection_slug}`.
    // NOTE(review): the value is inserted without percent-encoding.
    let mut endpoint = "/collections/{collection_slug}".to_string();
    if let Some(val) = inputs.get("collection_slug") {
        endpoint = endpoint.replace("{collection_slug}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!("GET /collections/{{collection_slug}} failed: {}", e).into(),
                ),
            );
        }
    }
    Ok(output)
}
#[actor(
ReplicateListDeploymentsActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_deployments(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/deployments".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.get(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("GET /deployments failed: {}", e).into()),
);
}
}
Ok(output)
}
#[actor(
ReplicateCreateDeploymentsActor,
inports::<100>(name, version, hardware, max_instances, min_instances, model),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_create_deployments(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/deployments".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.post(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut body = serde_json::Map::new();
if let Some(val) = inputs.get("name") {
body.insert("name".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("version") {
body.insert("version".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("hardware") {
body.insert("hardware".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("max_instances") {
body.insert("max_instances".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("min_instances") {
body.insert("min_instances".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("model") {
body.insert("model".to_string(), val.clone().into());
}
if !body.is_empty() {
builder = builder.json(&serde_json::Value::Object(body));
}
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("POST /deployments failed: {}", e).into()),
);
}
}
Ok(output)
}
/// Actor for `GET /deployments/{deployment_owner}/{deployment_name}`: fetches one
/// deployment.
///
/// The `deployment_owner` and `deployment_name` ports are substituted into the
/// path. Every HTTP reply is emitted on `response` as `{status, headers, body}`.
/// `body` is the parsed JSON, or the raw text when the reply is not JSON. Transport
/// failures are emitted on `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateReadDeploymentsActor,
    inports::<100>(deployment_owner, deployment_name),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_read_deployments(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholders to
    // replace are single-brace `{deployment_owner}` / `{deployment_name}`.
    // NOTE(review): values are inserted without percent-encoding.
    let mut endpoint = "/deployments/{deployment_owner}/{deployment_name}".to_string();
    if let Some(val) = inputs.get("deployment_owner") {
        endpoint = endpoint.replace("{deployment_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("deployment_name") {
        endpoint = endpoint.replace("{deployment_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "GET /deployments/{{deployment_owner}}/{{deployment_name}} failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
/// Actor for `PATCH /deployments/{deployment_owner}/{deployment_name}`: updates a
/// deployment.
///
/// The `deployment_owner` and `deployment_name` ports are substituted into the
/// path. `version`, `max_instances`, `hardware` and `min_instances`, when present,
/// become same-named fields of the JSON body. Every HTTP reply is emitted on
/// `response` as `{status, headers, body}`, and transport failures are emitted on
/// `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateUpdateDeploymentsActor,
    inports::<100>(deployment_owner, deployment_name, version, max_instances, hardware, min_instances),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_update_deployments(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholders to
    // replace are single-brace `{deployment_owner}` / `{deployment_name}`.
    // NOTE(review): values are inserted without percent-encoding.
    let mut endpoint = "/deployments/{deployment_owner}/{deployment_name}".to_string();
    if let Some(val) = inputs.get("deployment_owner") {
        endpoint = endpoint.replace("{deployment_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("deployment_name") {
        endpoint = endpoint.replace("{deployment_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.patch(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut body = serde_json::Map::new();
    if let Some(val) = inputs.get("version") {
        body.insert("version".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("max_instances") {
        body.insert("max_instances".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("hardware") {
        body.insert("hardware".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("min_instances") {
        body.insert("min_instances".to_string(), val.clone().into());
    }
    if !body.is_empty() {
        builder = builder.json(&serde_json::Value::Object(body));
    }
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "PATCH /deployments/{{deployment_owner}}/{{deployment_name}} failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
/// Actor for `DELETE /deployments/{deployment_owner}/{deployment_name}`: deletes a
/// deployment.
///
/// The `deployment_owner` and `deployment_name` ports are substituted into the
/// path. Every HTTP reply is emitted on `response` as `{status, headers, body}`.
/// `body` is the parsed JSON, or the raw text when the reply is not JSON. Transport
/// failures are emitted on `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateDeleteDeploymentsActor,
    inports::<100>(deployment_owner, deployment_name),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_delete_deployments(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholders to
    // replace are single-brace `{deployment_owner}` / `{deployment_name}`.
    // NOTE(review): values are inserted without percent-encoding.
    let mut endpoint = "/deployments/{deployment_owner}/{deployment_name}".to_string();
    if let Some(val) = inputs.get("deployment_owner") {
        endpoint = endpoint.replace("{deployment_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("deployment_name") {
        endpoint = endpoint.replace("{deployment_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.delete(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "DELETE /deployments/{{deployment_owner}}/{{deployment_name}} failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
#[actor(
ReplicateCreatePredictionsActor,
inports::<100>(deployment_owner, deployment_name, input, webhook, webhook_events_filter, stream),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_create_predictions(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let mut endpoint = "/deployments/{deployment_owner}/{deployment_name}/predictions".to_string();
if let Some(val) = inputs.get("deployment_owner") {
endpoint = endpoint.replace("{{deployment_owner}}", &super::message_to_str(val));
}
if let Some(val) = inputs.get("deployment_name") {
endpoint = endpoint.replace("{{deployment_name}}", &super::message_to_str(val));
}
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.post(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut body = serde_json::Map::new();
if let Some(val) = inputs.get("input") {
body.insert("input".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("webhook") {
body.insert("webhook".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("webhook_events_filter") {
body.insert("webhook_events_filter".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("stream") {
body.insert("stream".to_string(), val.clone().into());
}
if !body.is_empty() {
builder = builder.json(&serde_json::Value::Object(body));
}
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert("error".to_string(), Message::Error(format!("POST /deployments/{{deployment_owner}}/{{deployment_name}}/predictions failed: {}", e).into()));
}
}
Ok(output)
}
#[actor(
ReplicateListFilesActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_files(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/files".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.get(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("GET /files failed: {}", e).into()),
);
}
}
Ok(output)
}
#[actor(
ReplicateCreateFilesActor,
inports::<100>(metadata, content, type_, filename),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_create_files(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/files".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.post(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut body = serde_json::Map::new();
if let Some(val) = inputs.get("metadata") {
body.insert("metadata".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("content") {
body.insert("content".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("type_") {
body.insert("type".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("filename") {
body.insert("filename".to_string(), val.clone().into());
}
if !body.is_empty() {
builder = builder.json(&serde_json::Value::Object(body));
}
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("POST /files failed: {}", e).into()),
);
}
}
Ok(output)
}
/// Actor for `GET /files/{file_id}`: fetches one file's metadata.
///
/// The `file_id` port is substituted into the request path. Every HTTP reply is
/// emitted on `response` as `{status, headers, body}`. `body` is the parsed JSON,
/// or the raw text when the reply is not JSON. Transport failures are emitted on
/// `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateReadFilesActor,
    inports::<100>(file_id),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_read_files(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholder to
    // replace is the single-brace `{file_id}`.
    // NOTE(review): the value is inserted without percent-encoding.
    let mut endpoint = "/files/{file_id}".to_string();
    if let Some(val) = inputs.get("file_id") {
        endpoint = endpoint.replace("{file_id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(format!("GET /files/{{file_id}} failed: {}", e).into()),
            );
        }
    }
    Ok(output)
}
/// Actor for `DELETE /files/{file_id}`: deletes an uploaded file.
///
/// The `file_id` port is substituted into the request path. Every HTTP reply is
/// emitted on `response` as `{status, headers, body}`. `body` is the parsed JSON,
/// or the raw text when the reply is not JSON. Transport failures are emitted on
/// `error`.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateDeleteFilesActor,
    inports::<100>(file_id),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_delete_files(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholder to
    // replace is the single-brace `{file_id}`.
    // NOTE(review): the value is inserted without percent-encoding.
    let mut endpoint = "/files/{file_id}".to_string();
    if let Some(val) = inputs.get("file_id") {
        endpoint = endpoint.replace("{file_id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.delete(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(format!("DELETE /files/{{file_id}} failed: {}", e).into()),
            );
        }
    }
    Ok(output)
}
/// Actor for `GET /files/{file_id}/download`: downloads a file.
///
/// The `file_id` port is substituted into the path. `owner`, `expiry` and
/// `signature`, when present, are sent as query-string parameters. Every HTTP reply
/// is emitted on `response` as `{status, headers, body}`. `body` is the parsed
/// JSON, or the raw text when the reply is not JSON. Transport failures are emitted
/// on `error`.
///
/// NOTE(review): the body is read with `text()`, so binary file contents may not
/// survive intact. Confirm this is the intended use.
///
/// # Errors
/// Returns `Err` if the HTTP client cannot be built or no API token is configured.
#[actor(
    ReplicateDownloadDownloadActor,
    inports::<100>(file_id, owner, expiry, signature),
    outports::<50>(response, error),
    state(MemoryState)
)]
pub async fn replicate_download_download(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // `endpoint` is a plain string, not a format string, so the placeholder to
    // replace is the single-brace `{file_id}`.
    // NOTE(review): the value is inserted without percent-encoding.
    let mut endpoint = "/files/{file_id}/download".to_string();
    if let Some(val) = inputs.get("file_id") {
        endpoint = endpoint.replace("{file_id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut query_pairs: Vec<(&str, String)> = Vec::new();
    if let Some(val) = inputs.get("owner") {
        query_pairs.push(("owner", super::message_to_str(val)));
    }
    if let Some(val) = inputs.get("expiry") {
        query_pairs.push(("expiry", super::message_to_str(val)));
    }
    if let Some(val) = inputs.get("signature") {
        query_pairs.push(("signature", super::message_to_str(val)));
    }
    if !query_pairs.is_empty() {
        builder = builder.query(&query_pairs);
    }
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not visible ASCII are skipped.
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(format!("GET /files/{{file_id}}/download failed: {}", e).into()),
            );
        }
    }
    Ok(output)
}
#[actor(
ReplicateListHardwareActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_hardware(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/hardware".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.get(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("GET /hardware failed: {}", e).into()),
);
}
}
Ok(output)
}
#[actor(
ReplicateListModelsActor,
inports::<100>(sort_by, sort_direction),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_models(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let endpoint = "/models".to_string();
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.get(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut query_pairs: Vec<(&str, String)> = Vec::new();
if let Some(val) = inputs.get("sort_by") {
query_pairs.push(("sort_by", super::message_to_str(val)));
}
if let Some(val) = inputs.get("sort_direction") {
query_pairs.push(("sort_direction", super::message_to_str(val)));
}
if !query_pairs.is_empty() {
builder = builder.query(&query_pairs);
}
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert(
"error".to_string(),
Message::Error(format!("GET /models failed: {}", e).into()),
);
}
}
Ok(output)
}
// Actor for `POST /models`: creates a model on Replicate.
//
// Every inport that carries a value is copied into the JSON request body under
// the same key; inports without a value are omitted, and no body is attached
// when none are present. The HTTP status, response headers and body (parsed as
// JSON when possible, otherwise kept as a string) are emitted on `response`,
// whatever the status code. A failed `send()` (e.g. connection error or the
// 30 s client timeout) is emitted on `error`.
#[actor(
ReplicateCreateModelsActor,
inports::<100>(owner, cover_image_url, name, hardware, description, visibility, paper_url, github_url, license_url),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_create_models(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    // Body keys, in the order the original request builder added them.
    const BODY_FIELDS: [&str; 9] = [
        "owner",
        "cover_image_url",
        "name",
        "hardware",
        "description",
        "visibility",
        "paper_url",
        "github_url",
        "license_url",
    ];

    let payload = context.get_payload();
    let config = context.get_config();
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), "/models");

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let request = apply_auth(
        config,
        http.post(&url).header("Content-Type", "application/json"),
    )?;

    let mut json_body = serde_json::Map::new();
    for field in BODY_FIELDS.iter().copied() {
        if let Some(value) = payload.get(field) {
            json_body.insert(field.to_string(), value.clone().into());
        }
    }
    let request = if json_body.is_empty() {
        request
    } else {
        request.json(&Value::Object(json_body))
    };

    let output: HashMap<String, Message> = match request.send().await {
        Err(e) => HashMap::from([(
            "error".to_string(),
            Message::Error(format!("POST /models failed: {}", e).into()),
        )]),
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not valid visible ASCII are skipped.
            let mut headers: HashMap<String, String> = HashMap::new();
            for (name, value) in resp.headers() {
                if let Ok(text) = value.to_str() {
                    headers.insert(name.to_string(), text.to_string());
                }
            }
            let raw = resp.text().await.unwrap_or_default();
            let parsed = match serde_json::from_str::<Value>(&raw) {
                Ok(v) => v,
                Err(_) => Value::String(raw),
            };
            let envelope = json!({
                "status": status,
                "headers": headers,
                "body": parsed,
            });
            HashMap::from([(
                "response".to_string(),
                Message::object(EncodableValue::from(envelope)),
            )])
        }
    };
    Ok(output)
}
// Actor for `GET /models/{model_owner}/{model_name}`: fetches a single model.
//
// The `model_owner` and `model_name` inports are substituted into the path
// template. The HTTP status, response headers and body (parsed as JSON when
// possible, otherwise kept as a string) are emitted on `response`, whatever
// the status code. A failed `send()` (e.g. connection error or the 30 s client
// timeout) is emitted on `error`.
#[actor(
ReplicateReadModelsActor,
inports::<100>(model_owner, model_name),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_read_models(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholders are single-braced and must be matched as `{name}`.
    let mut endpoint = "/models/{model_owner}/{model_name}".to_string();
    if let Some(val) = inputs.get("model_owner") {
        endpoint = endpoint.replace("{model_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("model_name") {
        endpoint = endpoint.replace("{model_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!("GET /models/{{model_owner}}/{{model_name}} failed: {}", e).into(),
                ),
            );
        }
    }
    Ok(output)
}
// Actor for `PATCH /models/{model_owner}/{model_name}`: updates a model's metadata.
//
// `model_owner` and `model_name` are substituted into the path. Any of
// `weights_url`, `description`, `license_url`, `github_url`, `paper_url`,
// `readme` that carry a value are copied into the JSON body under the same
// key; no body is attached when none are present. The HTTP status, response
// headers and body (JSON when parseable, otherwise a string) are emitted on
// `response`. A failed `send()` is emitted on `error`.
#[actor(
ReplicateUpdateModelsActor,
inports::<100>(model_owner, model_name, weights_url, description, license_url, github_url, paper_url, readme),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_update_models(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholders are single-braced and must be matched as `{name}`.
    let mut endpoint = "/models/{model_owner}/{model_name}".to_string();
    if let Some(val) = inputs.get("model_owner") {
        endpoint = endpoint.replace("{model_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("model_name") {
        endpoint = endpoint.replace("{model_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.patch(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut body = serde_json::Map::new();
    if let Some(val) = inputs.get("weights_url") {
        body.insert("weights_url".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("description") {
        body.insert("description".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("license_url") {
        body.insert("license_url".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("github_url") {
        body.insert("github_url".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("paper_url") {
        body.insert("paper_url".to_string(), val.clone().into());
    }
    if let Some(val) = inputs.get("readme") {
        body.insert("readme".to_string(), val.clone().into());
    }
    if !body.is_empty() {
        builder = builder.json(&serde_json::Value::Object(body));
    }
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!("PATCH /models/{{model_owner}}/{{model_name}} failed: {}", e).into(),
                ),
            );
        }
    }
    Ok(output)
}
// Actor for `DELETE /models/{model_owner}/{model_name}`: deletes a model.
//
// `model_owner` and `model_name` are substituted into the path. The HTTP
// status, response headers and body (JSON when parseable, otherwise a string)
// are emitted on `response`, whatever the status code. A failed `send()` is
// emitted on `error`.
#[actor(
ReplicateDeleteModelsActor,
inports::<100>(model_owner, model_name),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_delete_models(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholders are single-braced and must be matched as `{name}`.
    let mut endpoint = "/models/{model_owner}/{model_name}".to_string();
    if let Some(val) = inputs.get("model_owner") {
        endpoint = endpoint.replace("{model_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("model_name") {
        endpoint = endpoint.replace("{model_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.delete(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "DELETE /models/{{model_owner}}/{{model_name}} failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
// Actor for `GET /models/{model_owner}/{model_name}/examples`: lists a model's examples.
//
// `model_owner` and `model_name` are substituted into the path. The HTTP
// status, response headers and body (JSON when parseable, otherwise a string)
// are emitted on `response`, whatever the status code. A failed `send()` is
// emitted on `error`.
#[actor(
ReplicateReadExamplesActor,
inports::<100>(model_owner, model_name),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_read_examples(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholders are single-braced and must be matched as `{name}`.
    let mut endpoint = "/models/{model_owner}/{model_name}/examples".to_string();
    if let Some(val) = inputs.get("model_owner") {
        endpoint = endpoint.replace("{model_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("model_name") {
        endpoint = endpoint.replace("{model_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "GET /models/{{model_owner}}/{{model_name}}/examples failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
// Actor for `GET /models/{model_owner}/{model_name}/readme`: fetches a model's README.
//
// `model_owner` and `model_name` are substituted into the path. The HTTP
// status, response headers and body (JSON when parseable, otherwise the raw
// text) are emitted on `response`, whatever the status code. A failed `send()`
// is emitted on `error`.
#[actor(
ReplicateReadReadmeActor,
inports::<100>(model_owner, model_name),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_read_readme(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholders are single-braced and must be matched as `{name}`.
    let mut endpoint = "/models/{model_owner}/{model_name}/readme".to_string();
    if let Some(val) = inputs.get("model_owner") {
        endpoint = endpoint.replace("{model_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("model_name") {
        endpoint = endpoint.replace("{model_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "GET /models/{{model_owner}}/{{model_name}}/readme failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
// Actor for `GET /models/{model_owner}/{model_name}/versions`: lists a model's versions.
//
// `model_owner` and `model_name` are substituted into the path. The HTTP
// status, response headers and body (JSON when parseable, otherwise a string)
// are emitted on `response`, whatever the status code. A failed `send()` is
// emitted on `error`.
#[actor(
ReplicateReadVersionsActor,
inports::<100>(model_owner, model_name),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_read_versions(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholders are single-braced and must be matched as `{name}`.
    let mut endpoint = "/models/{model_owner}/{model_name}/versions".to_string();
    if let Some(val) = inputs.get("model_owner") {
        endpoint = endpoint.replace("{model_owner}", &super::message_to_str(val));
    }
    if let Some(val) = inputs.get("model_name") {
        endpoint = endpoint.replace("{model_name}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!(
                        "GET /models/{{model_owner}}/{{model_name}}/versions failed: {}",
                        e
                    )
                    .into(),
                ),
            );
        }
    }
    Ok(output)
}
#[actor(
ReplicateDeleteVersionsActor,
inports::<100>(model_owner, model_name, version_id),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_delete_versions(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let mut endpoint = "/models/{model_owner}/{model_name}/versions/{version_id}".to_string();
if let Some(val) = inputs.get("model_owner") {
endpoint = endpoint.replace("{{model_owner}}", &super::message_to_str(val));
}
if let Some(val) = inputs.get("model_name") {
endpoint = endpoint.replace("{{model_name}}", &super::message_to_str(val));
}
if let Some(val) = inputs.get("version_id") {
endpoint = endpoint.replace("{{version_id}}", &super::message_to_str(val));
}
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.delete(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert("error".to_string(), Message::Error(format!("DELETE /models/{{model_owner}}/{{model_name}}/versions/{{version_id}} failed: {}", e).into()));
}
}
Ok(output)
}
#[actor(
ReplicateCreateTrainingsActor,
inports::<100>(model_owner, model_name, version_id, destination, input, webhook_events_filter, webhook),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_create_trainings(
context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
let inputs = context.get_payload();
let actor_config = context.get_config();
let mut endpoint =
"/models/{model_owner}/{model_name}/versions/{version_id}/trainings".to_string();
if let Some(val) = inputs.get("model_owner") {
endpoint = endpoint.replace("{{model_owner}}", &super::message_to_str(val));
}
if let Some(val) = inputs.get("model_name") {
endpoint = endpoint.replace("{{model_name}}", &super::message_to_str(val));
}
if let Some(val) = inputs.get("version_id") {
endpoint = endpoint.replace("{{version_id}}", &super::message_to_str(val));
}
let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()?;
let mut builder = client.post(&url);
builder = builder.header("Content-Type", "application/json");
builder = apply_auth(actor_config, builder)?;
let mut body = serde_json::Map::new();
if let Some(val) = inputs.get("destination") {
body.insert("destination".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("input") {
body.insert("input".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("webhook_events_filter") {
body.insert("webhook_events_filter".to_string(), val.clone().into());
}
if let Some(val) = inputs.get("webhook") {
body.insert("webhook".to_string(), val.clone().into());
}
if !body.is_empty() {
builder = builder.json(&serde_json::Value::Object(body));
}
let mut output = HashMap::new();
match builder.send().await {
Ok(resp) => {
let status = resp.status().as_u16();
let headers: HashMap<String, String> = resp
.headers()
.iter()
.filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
.collect();
let body_text = resp.text().await.unwrap_or_default();
let body_value: Value =
serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
output.insert(
"response".to_string(),
Message::object(EncodableValue::from(json!({
"status": status,
"headers": headers,
"body": body_value,
}))),
);
}
Err(e) => {
output.insert("error".to_string(), Message::Error(format!("POST /models/{{model_owner}}/{{model_name}}/versions/{{version_id}}/trainings failed: {}", e).into()));
}
}
Ok(output)
}
// Actor for `GET /predictions`: lists predictions.
//
// `created_after`, `created_before` and `source` inports that carry a value
// are stringified and sent as query parameters, in that order; no query string
// is added when none are present. The HTTP status, response headers and body
// (JSON when parseable, otherwise a string) are emitted on `response`,
// whatever the status code. A failed `send()` is emitted on `error`.
#[actor(
ReplicateListPredictionsActor,
inports::<100>(created_after, created_before, source),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_predictions(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let payload = context.get_payload();
    let config = context.get_config();
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), "/predictions");

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let request = apply_auth(
        config,
        http.get(&url).header("Content-Type", "application/json"),
    )?;

    let params: Vec<(&str, String)> = ["created_after", "created_before", "source"]
        .iter()
        .copied()
        .filter_map(|key| payload.get(key).map(|val| (key, super::message_to_str(val))))
        .collect();
    let request = if params.is_empty() {
        request
    } else {
        request.query(&params)
    };

    let output: HashMap<String, Message> = match request.send().await {
        Err(e) => HashMap::from([(
            "error".to_string(),
            Message::Error(format!("GET /predictions failed: {}", e).into()),
        )]),
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not valid visible ASCII are skipped.
            let mut headers: HashMap<String, String> = HashMap::new();
            for (name, value) in resp.headers() {
                if let Ok(text) = value.to_str() {
                    headers.insert(name.to_string(), text.to_string());
                }
            }
            let raw = resp.text().await.unwrap_or_default();
            let parsed = match serde_json::from_str::<Value>(&raw) {
                Ok(v) => v,
                Err(_) => Value::String(raw),
            };
            let envelope = json!({
                "status": status,
                "headers": headers,
                "body": parsed,
            });
            HashMap::from([(
                "response".to_string(),
                Message::object(EncodableValue::from(envelope)),
            )])
        }
    };
    Ok(output)
}
// Actor for `GET /predictions/{prediction_id}`: fetches a single prediction.
//
// `prediction_id` is substituted into the path. The HTTP status, response
// headers and body (JSON when parseable, otherwise a string) are emitted on
// `response`, whatever the status code. A failed `send()` is emitted on `error`.
#[actor(
ReplicateReadPredictionsActor,
inports::<100>(prediction_id),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_read_predictions(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholder is single-braced and must be matched as `{prediction_id}`.
    let mut endpoint = "/predictions/{prediction_id}".to_string();
    if let Some(val) = inputs.get("prediction_id") {
        endpoint = endpoint.replace("{prediction_id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(format!("GET /predictions/{{prediction_id}} failed: {}", e).into()),
            );
        }
    }
    Ok(output)
}
// Actor for `POST /predictions/{prediction_id}/cancel`: cancels a prediction.
//
// `prediction_id` is substituted into the path. No request body is sent. The
// HTTP status, response headers and body (JSON when parseable, otherwise a
// string) are emitted on `response`, whatever the status code. A failed
// `send()` is emitted on `error`.
#[actor(
ReplicateCancelCancelActor,
inports::<100>(prediction_id),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_cancel_cancel(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholder is single-braced and must be matched as `{prediction_id}`.
    let mut endpoint = "/predictions/{prediction_id}/cancel".to_string();
    if let Some(val) = inputs.get("prediction_id") {
        endpoint = endpoint.replace("{prediction_id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.post(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(
                    format!("POST /predictions/{{prediction_id}}/cancel failed: {}", e).into(),
                ),
            );
        }
    }
    Ok(output)
}
// Actor for `GET /search`: searches Replicate.
//
// `query` and `limit` inports that carry a value are stringified and sent as
// query parameters, in that order; no query string is added when neither is
// present. The HTTP status, response headers and body (JSON when parseable,
// otherwise a string) are emitted on `response`, whatever the status code. A
// failed `send()` is emitted on `error`.
#[actor(
ReplicateSearchSearchActor,
inports::<100>(query, limit),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_search_search(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let payload = context.get_payload();
    let config = context.get_config();
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), "/search");

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let request = apply_auth(
        config,
        http.get(&url).header("Content-Type", "application/json"),
    )?;

    let params: Vec<(&str, String)> = ["query", "limit"]
        .iter()
        .copied()
        .filter_map(|key| payload.get(key).map(|val| (key, super::message_to_str(val))))
        .collect();
    let request = if params.is_empty() {
        request
    } else {
        request.query(&params)
    };

    let output: HashMap<String, Message> = match request.send().await {
        Err(e) => HashMap::from([(
            "error".to_string(),
            Message::Error(format!("GET /search failed: {}", e).into()),
        )]),
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not valid visible ASCII are skipped.
            let mut headers: HashMap<String, String> = HashMap::new();
            for (name, value) in resp.headers() {
                if let Ok(text) = value.to_str() {
                    headers.insert(name.to_string(), text.to_string());
                }
            }
            let raw = resp.text().await.unwrap_or_default();
            let parsed = match serde_json::from_str::<Value>(&raw) {
                Ok(v) => v,
                Err(_) => Value::String(raw),
            };
            let envelope = json!({
                "status": status,
                "headers": headers,
                "body": parsed,
            });
            HashMap::from([(
                "response".to_string(),
                Message::object(EncodableValue::from(envelope)),
            )])
        }
    };
    Ok(output)
}
// Actor for `GET /trainings`: lists training jobs.
//
// The `trigger` inport only fires the actor; its value is not read. The HTTP
// status, response headers and body (JSON when parseable, otherwise a string)
// are emitted on `response`, whatever the status code. A failed `send()` is
// emitted on `error`.
#[actor(
ReplicateListTrainingsActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_trainings(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let _payload = context.get_payload();
    let config = context.get_config();
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), "/trainings");

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let request = apply_auth(
        config,
        http.get(&url).header("Content-Type", "application/json"),
    )?;

    let output: HashMap<String, Message> = match request.send().await {
        Err(e) => HashMap::from([(
            "error".to_string(),
            Message::Error(format!("GET /trainings failed: {}", e).into()),
        )]),
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not valid visible ASCII are skipped.
            let mut headers: HashMap<String, String> = HashMap::new();
            for (name, value) in resp.headers() {
                if let Ok(text) = value.to_str() {
                    headers.insert(name.to_string(), text.to_string());
                }
            }
            let raw = resp.text().await.unwrap_or_default();
            let parsed = match serde_json::from_str::<Value>(&raw) {
                Ok(v) => v,
                Err(_) => Value::String(raw),
            };
            let envelope = json!({
                "status": status,
                "headers": headers,
                "body": parsed,
            });
            HashMap::from([(
                "response".to_string(),
                Message::object(EncodableValue::from(envelope)),
            )])
        }
    };
    Ok(output)
}
// Actor for `GET /trainings/{training_id}`: fetches a single training job.
//
// `training_id` is substituted into the path. The HTTP status, response
// headers and body (JSON when parseable, otherwise a string) are emitted on
// `response`, whatever the status code. A failed `send()` is emitted on `error`.
#[actor(
ReplicateReadTrainingsActor,
inports::<100>(training_id),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_read_trainings(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let inputs = context.get_payload();
    let actor_config = context.get_config();
    // The template is a plain string literal, not a `format!` string, so its
    // placeholder is single-braced and must be matched as `{training_id}`.
    let mut endpoint = "/trainings/{training_id}".to_string();
    if let Some(val) = inputs.get("training_id") {
        endpoint = endpoint.replace("{training_id}", &super::message_to_str(val));
    }
    let url = format!("{}{}", BASE_URL.trim_end_matches('/'), endpoint);
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let mut builder = client.get(&url);
    builder = builder.header("Content-Type", "application/json");
    builder = apply_auth(actor_config, builder)?;
    let mut output = HashMap::new();
    match builder.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let headers: HashMap<String, String> = resp
                .headers()
                .iter()
                .filter_map(|(k, v)| v.to_str().ok().map(|val| (k.to_string(), val.to_string())))
                .collect();
            let body_text = resp.text().await.unwrap_or_default();
            // Fall back to the raw text when the body is not valid JSON.
            let body_value: Value =
                serde_json::from_str(&body_text).unwrap_or(Value::String(body_text));
            output.insert(
                "response".to_string(),
                Message::object(EncodableValue::from(json!({
                    "status": status,
                    "headers": headers,
                    "body": body_value,
                }))),
            );
        }
        Err(e) => {
            output.insert(
                "error".to_string(),
                Message::Error(format!("GET /trainings/{{training_id}} failed: {}", e).into()),
            );
        }
    }
    Ok(output)
}
// Actor for `GET /webhooks/default/secret`: fetches the default webhook signing secret.
//
// The `trigger` inport only fires the actor; its value is not read. The HTTP
// status, response headers and body (JSON when parseable, otherwise a string)
// are emitted on `response`, whatever the status code. A failed `send()` is
// emitted on `error`.
#[actor(
ReplicateListSecretActor,
inports::<100>(trigger),
outports::<50>(response, error),
state(MemoryState)
)]
pub async fn replicate_list_secret(
    context: ActorContext,
) -> Result<HashMap<String, Message>, Error> {
    let _payload = context.get_payload();
    let config = context.get_config();
    let url = format!(
        "{}{}",
        BASE_URL.trim_end_matches('/'),
        "/webhooks/default/secret"
    );

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()?;
    let request = apply_auth(
        config,
        http.get(&url).header("Content-Type", "application/json"),
    )?;

    let output: HashMap<String, Message> = match request.send().await {
        Err(e) => HashMap::from([(
            "error".to_string(),
            Message::Error(format!("GET /webhooks/default/secret failed: {}", e).into()),
        )]),
        Ok(resp) => {
            let status = resp.status().as_u16();
            // Header values that are not valid visible ASCII are skipped.
            let mut headers: HashMap<String, String> = HashMap::new();
            for (name, value) in resp.headers() {
                if let Ok(text) = value.to_str() {
                    headers.insert(name.to_string(), text.to_string());
                }
            }
            let raw = resp.text().await.unwrap_or_default();
            let parsed = match serde_json::from_str::<Value>(&raw) {
                Ok(v) => v,
                Err(_) => Value::String(raw),
            };
            let envelope = json!({
                "status": status,
                "headers": headers,
                "body": parsed,
            });
            HashMap::from([(
                "response".to_string(),
                Message::object(EncodableValue::from(envelope)),
            )])
        }
    };
    Ok(output)
}