// taskforceai-sdk 0.1.0
//
// Rust SDK for TaskForceAI.
// Documentation
use crate::error::TaskForceAIError;
use crate::types::{
    SubmitTaskResponse, TaskForceAIOptions, TaskStatus, TaskStatusValue, TaskSubmissionOptions,
};
use std::time::Duration;
use tokio::time::sleep;

/// Base URL of the TaskForceAI developer API, used when no `base_url` option is given.
pub const DEFAULT_BASE_URL: &str = "https://taskforceai.chat/api/developer";

/// HTTP request timeout in seconds, used when no `timeout` option is given.
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Delay between status polls in milliseconds, used when no poll interval is given.
pub const DEFAULT_POLL_INTERVAL_MS: u64 = 1_000;

/// Number of status polls made before giving up, used when no attempt limit is given.
pub const DEFAULT_MAX_POLL_ATTEMPTS: u32 = 60;

/// Asynchronous client for the TaskForceAI developer API.
///
/// Construct it with `TaskForceAI::new`. In mock mode, requests are answered with
/// canned responses and no HTTP request is sent.
pub struct TaskForceAI {
    /// API key, sent as a bearer token when non-empty. It may be empty only in mock mode.
    pub(crate) api_key: String,
    /// API root with any trailing `/` removed. Request paths are appended to it.
    pub(crate) base_url: String,
    /// Request timeout.
    // The timeout is already applied to `client` when it is built. This copy appears
    // not to be read anywhere, hence the lint allowance (confirm before removing).
    #[allow(dead_code)]
    pub(crate) timeout: Duration,
    /// When `true`, requests return canned responses instead of using the network.
    pub(crate) mock_mode: bool,
    /// HTTP client used for all requests, built with `timeout`.
    pub(crate) client: reqwest::Client,
}

impl TaskForceAI {
    pub fn new(options: TaskForceAIOptions) -> Result<Self, TaskForceAIError> {
        let mock_mode = options.mock_mode.unwrap_or(false);
        let api_key = options.api_key.unwrap_or_default();

        if !mock_mode && api_key.is_empty() {
            return Err(TaskForceAIError::MissingApiKey);
        }

        let base_url = options
            .base_url
            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
            .trim_end_matches('/')
            .to_string();

        let timeout = Duration::from_secs(options.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));

        let client = reqwest::Client::builder().timeout(timeout).build()?;

        Ok(Self {
            api_key,
            base_url,
            timeout,
            mock_mode,
            client,
        })
    }

    pub(crate) async fn request<T>(
        &self,
        method: reqwest::Method,
        path: &str,
        body: Option<serde_json::Value>,
    ) -> Result<T, TaskForceAIError>
    where
        T: serde::de::DeserializeOwned,
    {
        if self.mock_mode {
            return self.mock_response(path, &method);
        }

        let url = format!("{}{}", self.base_url, path);
        let mut request = self.client.request(method, &url);

        if !self.api_key.is_empty() {
            request = request.bearer_auth(&self.api_key);
        }

        request = request.header("X-SDK-Language", "rust");

        if let Some(b) = body {
            request = request.json(&b);
        }

        let response = request.send().await?;
        let status = response.status();

        if !status.is_success() {
            let message = response
                .text()
                .await
                .unwrap_or_else(|_| "Failed to read error message from response body".to_string());
            return Err(TaskForceAIError::Api { status, message });
        }

        Ok(response.json().await?)
    }

    fn mock_response<T>(&self, path: &str, method: &reqwest::Method) -> Result<T, TaskForceAIError>
    where
        T: serde::de::DeserializeOwned,
    {
        let val = if method == reqwest::Method::POST && path == "/run" {
            serde_json::json!({ "taskId": "mock-task-123" })
        } else if path.starts_with("/status/") {
            serde_json::json!({
                "taskId": "mock-task-123",
                "status": "completed",
                "result": "This is a mock response. Configure your API key to get real results."
            })
        } else {
            serde_json::json!({ "status": "ok" })
        };

        Ok(serde_json::from_value(val)?)
    }

    pub async fn submit_task(
        &self,
        prompt: &str,
        options: Option<TaskSubmissionOptions>,
    ) -> Result<String, TaskForceAIError> {
        if prompt.trim().is_empty() {
            return Err(TaskForceAIError::EmptyPrompt);
        }

        let mut body = serde_json::json!({ "prompt": prompt });
        if let Some(opts) = options {
            if let Some(obj) = body.as_object_mut() {
                obj.insert("options".to_string(), serde_json::to_value(opts)?);
            }
        }

        let response: SubmitTaskResponse = self
            .request(reqwest::Method::POST, "/run", Some(body))
            .await?;
        Ok(response.task_id)
    }

    pub async fn get_task_status(&self, task_id: &str) -> Result<TaskStatus, TaskForceAIError> {
        if task_id.trim().is_empty() {
            return Err(TaskForceAIError::EmptyTaskId);
        }
        self.request(reqwest::Method::GET, &format!("/status/{}", task_id), None)
            .await
    }

    pub async fn wait_for_completion(
        &self,
        task_id: &str,
        poll_interval: Option<Duration>,
        max_attempts: Option<u32>,
    ) -> Result<TaskStatus, TaskForceAIError> {
        let interval = poll_interval.unwrap_or(Duration::from_millis(DEFAULT_POLL_INTERVAL_MS));
        let max = max_attempts.unwrap_or(DEFAULT_MAX_POLL_ATTEMPTS);

        for _ in 0..max {
            let status = self.get_task_status(task_id).await?;
            match status.status {
                TaskStatusValue::Completed => return Ok(status),
                TaskStatusValue::Failed => {
                    return Err(TaskForceAIError::TaskFailed(
                        status.error.unwrap_or_else(|| "Unknown error".to_string()),
                    ))
                }
                TaskStatusValue::Processing => (),
            }
            sleep(interval).await;
        }

        Err(TaskForceAIError::Timeout)
    }

    pub async fn run_task(
        &self,
        prompt: &str,
        options: Option<TaskSubmissionOptions>,
        poll_interval: Option<Duration>,
        max_attempts: Option<u32>,
    ) -> Result<TaskStatus, TaskForceAIError> {
        let task_id = self.submit_task(prompt, options).await?;
        self.wait_for_completion(&task_id, poll_interval, max_attempts)
            .await
    }
}