langchain_rust/tools/wolfram/
wolfram.rs1use async_trait::async_trait;
2use serde_json::Value;
3
4use crate::tools::Tool;
5use std::error::Error;
6
7#[derive(Debug, serde::Serialize, serde::Deserialize)]
8struct WolframError {
9 code: String,
10 msg: String,
11}
12
13#[derive(Debug, serde::Serialize, serde::Deserialize)]
14#[serde(untagged)]
15enum WolframErrorStatus {
16 Error(WolframError),
17 NoError(bool),
18}
19
20#[derive(Debug, serde::Serialize, serde::Deserialize)]
21struct WolframResponse {
22 queryresult: WolframResponseContent,
23}
24
25#[derive(Debug, serde::Serialize, serde::Deserialize)]
26struct WolframResponseContent {
27 success: bool,
28 error: WolframErrorStatus,
29 pods: Option<Vec<Pod>>,
30}
31
32#[derive(Debug, serde::Serialize, serde::Deserialize)]
33struct Pod {
34 title: String,
35 subpods: Vec<Subpod>,
36}
37
38impl From<Pod> for String {
39 fn from(pod: Pod) -> String {
40 let subpods_str: Vec<String> = pod
41 .subpods
42 .into_iter()
43 .map(String::from)
44 .filter(|s| !s.is_empty())
45 .collect();
46
47 if subpods_str.is_empty() {
48 return String::from("");
49 }
50
51 format!(
52 "{{\"title\": {},\"subpods\": [{}]}}",
53 pod.title,
54 subpods_str.join(",")
55 )
56 }
57}
58
59#[derive(Debug, serde::Serialize, serde::Deserialize)]
60struct Subpod {
61 title: String,
62 plaintext: String,
63}
64
65impl From<Subpod> for String {
66 fn from(subpod: Subpod) -> String {
67 if subpod.plaintext.is_empty() {
68 return String::from("");
69 }
70
71 format!(
72 "{{\"title\": \"{}\",\"plaintext\": \"{}\"}}",
73 subpod.title,
74 subpod.plaintext.replace("\n", " // ")
75 )
76 }
77}
78
79pub struct Wolfram {
81 app_id: String,
82 exclude_pods: Vec<String>,
83 client: reqwest::Client,
84}
85
86impl Wolfram {
87 pub fn new(app_id: String) -> Self {
88 Self {
89 app_id,
90 exclude_pods: Vec::new(),
91 client: reqwest::Client::new(),
92 }
93 }
94
95 pub fn with_excludes<S: AsRef<str>>(mut self, exclude_pods: &[S]) -> Self {
96 self.exclude_pods = exclude_pods.iter().map(|s| s.as_ref().to_owned()).collect();
97 self
98 }
99
100 pub fn with_app_id<S: AsRef<str>>(mut self, app_id: S) -> Self {
101 self.app_id = app_id.as_ref().to_owned();
102 self
103 }
104}
105
106impl Default for Wolfram {
107 fn default() -> Wolfram {
108 Wolfram {
109 app_id: std::env::var("WOLFRAM_APP_ID").unwrap_or_default(),
110 exclude_pods: Vec::new(),
111 client: reqwest::Client::new(),
112 }
113 }
114}
115
116#[async_trait]
117impl Tool for Wolfram {
118 fn name(&self) -> String {
119 String::from("Wolfram")
120 }
121
122 fn description(&self) -> String {
123 String::from(
124 "Wolfram Solver leverages the Wolfram Alpha computational engine
125 to solve complex queries. Input should be a valid mathematical
126 expression or query formulated in a way that Wolfram Alpha can
127 interpret.",
128 )
129 }
130 async fn run(&self, input: Value) -> Result<String, Box<dyn Error>> {
131 let input = input.as_str().ok_or("Invalid input")?;
132 let mut url = format!(
133 "https://api.wolframalpha.com/v2/query?appid={}&input={}&output=JSON&format=plaintext&podstate=Result__Step-by-step+solution",
134 &self.app_id,
135 urlencoding::encode(input)
136 );
137
138 if !self.exclude_pods.is_empty() {
139 url += &format!("&excludepodid={}", self.exclude_pods.join(","));
140 }
141
142 let response: WolframResponse = self.client.get(&url).send().await?.json().await?;
143
144 if let WolframErrorStatus::Error(error) = response.queryresult.error {
145 return Err(Box::new(std::io::Error::new(
146 std::io::ErrorKind::Other,
147 format!("Wolfram Error {}: {}", error.code, error.msg),
148 )));
149 } else if !response.queryresult.success {
150 return Err(Box::new(std::io::Error::new(
151 std::io::ErrorKind::Other,
152 "Wolfram Error invalid query input: The query requested can not be processed by Wolfram".to_string(),
153 )));
154 }
155
156 let pods_str: Vec<String> = response
157 .queryresult
158 .pods
159 .unwrap_or_default()
160 .into_iter()
161 .map(String::from)
162 .filter(|s| !s.is_empty())
163 .collect();
164
165 Ok(format!("{{\"pods\": [{}]}}", pods_str.join(",")))
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one query through a default-configured tool with the "Plot" pod
    /// excluded. Marked `#[ignore]`, so it only runs on explicit request.
    // NOTE(review): presumably needs network access and a valid
    // WOLFRAM_APP_ID in the environment - confirm before enabling in CI.
    #[tokio::test]
    #[ignore]
    async fn test_wolfram() {
        let query = "Solve x^2 - 2x + 1 = 0";
        let tool = Wolfram::default().with_excludes(&["Plot"]);

        let result = tool.call(query).await;
        assert!(result.is_ok());

        let answer = result.unwrap();
        println!("{}", answer);
    }
}